forked from open-webui/open-webui

fix: address comment in pr #1687

parent d5f60b119c
commit c9c9660459

4 changed files with 92 additions and 43 deletions
@@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}


def get_ollama_endpoint(url_idx: int = 0):
    return app.state.OLLAMA_BASE_URLS[url_idx]


class UrlUpdateForm(BaseModel):
    urls: List[str]

@@ -39,8 +39,6 @@ import json

import sentence_transformers

from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm

from apps.web.models.documents import (
    Documents,
    DocumentForm,
@@ -48,6 +46,7 @@ from apps.web.models.documents import (
)

from apps.rag.utils import (
    get_model_path,
    query_embeddings_doc,
    query_embeddings_function,
    query_embeddings_collection,
@@ -60,6 +59,7 @@ from utils.misc import (
    extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user

from config import (
    SRC_LOG_LEVELS,
    UPLOAD_DIR,
@@ -68,8 +68,10 @@ from config import (
    RAG_RELEVANCE_THRESHOLD,
    RAG_EMBEDDING_ENGINE,
    RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
@@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])

app = FastAPI()


app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP


app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
@@ -104,18 +104,28 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.PDF_EXTRACT_IMAGES = False

if app.state.RAG_EMBEDDING_ENGINE == "":

def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
):
    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
        app.state.RAG_EMBEDDING_MODEL,
            get_model_path(embedding_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
        )
    else:
        app.state.sentence_transformer_ef = None

if not app.state.RAG_RERANKING_MODEL == "":

def update_reranking_model(
    reranking_model: str,
    update_model: bool = False,
):
    if reranking_model:
        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
        app.state.RAG_RERANKING_MODEL,
            get_model_path(reranking_model, update_model),
            device=DEVICE_TYPE,
            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
        )
@@ -123,8 +133,19 @@ else:
        app.state.sentence_transformer_rf = None


update_embedding_model(
    app.state.RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)

update_reranking_model(
    app.state.RAG_RERANKING_MODEL,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
)

origins = ["*"]


app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
@@ -200,15 +221,7 @@ async def update_embedding_config(
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.OPENAI_API_KEY = form_data.openai_config.key

            app.state.sentence_transformer_ef = None
        else:
            app.state.sentence_transformer_ef = (
                sentence_transformers.SentenceTransformer(
                    app.state.RAG_EMBEDDING_MODEL,
                    device=DEVICE_TYPE,
                    trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
                )
            )
        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        return {
            "status": True,
@@ -219,7 +232,6 @@ async def update_embedding_config(
                "key": app.state.OPENAI_API_KEY,
            },
        }

    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(
@@ -242,13 +254,7 @@ async def update_reranking_config(
    try:
        app.state.RAG_RERANKING_MODEL = form_data.reranking_model

        if app.state.RAG_RERANKING_MODEL == "":
            app.state.sentence_transformer_rf = None
        else:
            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
                app.state.RAG_RERANKING_MODEL,
                device=DEVICE_TYPE,
            )
        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)

        return {
            "status": True,
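The hunks above move the module-level model loading into `update_embedding_model()` / `update_reranking_model()` helpers that are called once at startup (passing the new `*_AUTO_UPDATE` flags) and again with `update_model=True` from the config endpoints, using the `get_model_path` helper imported from `apps.rag.utils`. Below is a minimal standalone sketch of that pattern, not the commit's code: `AppState`, `app_state`, the `DEVICE_TYPE` value, the example model name, and the stubbed `get_model_path` are illustrative assumptions.

```python
import sentence_transformers


class AppState:
    # Stand-in for FastAPI's app.state; "" selects the local sentence-transformers engine.
    RAG_EMBEDDING_ENGINE = ""
    sentence_transformer_ef = None


app_state = AppState()
DEVICE_TYPE = "cpu"  # assumption; the real value comes from config


def get_model_path(model: str, update_model: bool = False) -> str:
    # Stub for the helper defined in apps.rag.utils (see the next file's hunks);
    # here it just returns the model name unchanged.
    return model


def update_embedding_model(embedding_model: str, update_model: bool = False):
    # Load (or refresh) the local embedding model, or clear it when no model
    # name is set / a remote embedding engine is configured.
    if embedding_model and app_state.RAG_EMBEDDING_ENGINE == "":
        app_state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
            get_model_path(embedding_model, update_model),
            device=DEVICE_TYPE,
        )
    else:
        app_state.sentence_transformer_ef = None


# At startup the auto-update flag decides whether a download may happen; the
# settings endpoint later calls this again with update_model=True to force a refresh.
update_embedding_model("all-MiniLM-L6-v2", update_model=False)
```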
@@ -1,3 +1,4 @@
import os
import logging
import requests

@@ -8,6 +9,8 @@ from apps.ollama.main import (
    GenerateEmbeddingsForm,
)

from huggingface_hub import snapshot_download

from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import (
@@ -282,8 +285,6 @@ def rag_messages(

        extracted_collections.extend(collection)

    log.debug(f"relevant_contexts: {relevant_contexts}")

    context_string = ""
    for context in relevant_contexts:
        items = context["documents"][0]
@@ -319,6 +320,44 @@ def rag_messages(
    return messages


def get_model_path(model: str, update_model: bool = False):
    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")

    local_files_only = not update_model

    snapshot_kwargs = {
        "cache_dir": cache_dir,
        "local_files_only": local_files_only,
    }

    log.debug(f"embedding_model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Inspiration from upstream sentence_transformers
    if (
        os.path.exists(model)
        or ("\\" in model or model.count("/") > 1)
        and local_files_only
    ):
        # If fully qualified path exists, return input, else set repo_id
        return model
    elif "/" not in model:
        # Set valid repo_id for model short-name
        model = "sentence-transformers" + "/" + model

    snapshot_kwargs["repo_id"] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model


def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
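For reference, here is a hedged usage sketch of the `get_model_path()` helper added above: it resolves a model name to a local Hugging Face snapshot via `snapshot_download()`, staying offline (`local_files_only=True`) unless an update was requested. The `resolve_model_path` name and the example model are illustrative, not part of the commit.

```python
import os

from huggingface_hub import snapshot_download


def resolve_model_path(model: str, update_model: bool = False) -> str:
    # Same idea as get_model_path() above, trimmed down for illustration.
    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")  # optional custom cache location
    repo_id = model if "/" in model else f"sentence-transformers/{model}"
    try:
        return snapshot_download(
            repo_id=repo_id,
            cache_dir=cache_dir,
            # Offline lookup of the cached snapshot unless an update was requested.
            local_files_only=not update_model,
        )
    except Exception:
        # Fall back to the raw name and let SentenceTransformer resolve it itself.
        return model


# Uses the cached copy if present; if the model was never downloaded, the
# offline lookup fails and the plain name is returned instead.
print(resolve_model_path("all-MiniLM-L6-v2", update_model=False))
```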
@@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),

RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
@@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
if not RAG_RERANKING_MODEL == "":
    log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),

RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
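A quick check of the flag parsing used above (a sketch, not part of the commit): only the literal string "true", in any casing, turns the new `*_AUTO_UPDATE` options on; values such as "1" or "yes" are treated as false.

```python
import os

# Accepted: any casing of "true".
os.environ["RAG_EMBEDDING_MODEL_AUTO_UPDATE"] = "True"
print(os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true")  # True

# Not accepted by this pattern: "1" (or "yes"), so auto-update stays disabled.
os.environ["RAG_EMBEDDING_MODEL_AUTO_UPDATE"] = "1"
print(os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true")  # False
```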
Steven Kreitzer