forked from open-webui/open-webui
		
fix: address comment in pr #1687

parent d5f60b119c
commit c9c9660459

4 changed files with 92 additions and 43 deletions
@@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
     return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
 
 
-def get_ollama_endpoint(url_idx: int = 0):
-    return app.state.OLLAMA_BASE_URLS[url_idx]
-
-
 class UrlUpdateForm(BaseModel):
     urls: List[str]
 
@@ -39,8 +39,6 @@ import json
 
 import sentence_transformers
 
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
 from apps.web.models.documents import (
     Documents,
     DocumentForm,
@@ -48,6 +46,7 @@ from apps.web.models.documents import (
 )
 
 from apps.rag.utils import (
+    get_model_path,
     query_embeddings_doc,
     query_embeddings_function,
     query_embeddings_collection,
@@ -60,6 +59,7 @@ from utils.misc import (
     extract_folders_after_data_docs,
 )
 from utils.utils import get_current_user, get_admin_user
+
 from config import (
     SRC_LOG_LEVELS,
     UPLOAD_DIR,
@@ -68,8 +68,10 @@ from config import (
     RAG_RELEVANCE_THRESHOLD,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
@@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 app = FastAPI()
 
-
 app.state.TOP_K = RAG_TOP_K
 app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 
-
 app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
@@ -104,18 +104,28 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 
-if app.state.RAG_EMBEDDING_ENGINE == "":
+
+def update_embedding_model(
+    embedding_model: str,
+    update_model: bool = False,
+):
+    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
         app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-        app.state.RAG_EMBEDDING_MODEL,
+            get_model_path(embedding_model, update_model),
             device=DEVICE_TYPE,
             trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
         )
     else:
         app.state.sentence_transformer_ef = None
 
-if not app.state.RAG_RERANKING_MODEL == "":
+
+def update_reranking_model(
+    reranking_model: str,
+    update_model: bool = False,
+):
+    if reranking_model:
         app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-        app.state.RAG_RERANKING_MODEL,
+            get_model_path(reranking_model, update_model),
             device=DEVICE_TYPE,
             trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
         )
@@ -123,8 +133,19 @@ else:
         app.state.sentence_transformer_rf = None
 
 
+update_embedding_model(
+    app.state.RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+)
+
+update_reranking_model(
+    app.state.RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+)
+
 origins = ["*"]
 
+
 app.add_middleware(
     CORSMiddleware,
     allow_origins=origins,
@@ -200,15 +221,7 @@ async def update_embedding_config(
                 app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                 app.state.OPENAI_API_KEY = form_data.openai_config.key
 
-            app.state.sentence_transformer_ef = None
-        else:
-            app.state.sentence_transformer_ef = (
-                sentence_transformers.SentenceTransformer(
-                    app.state.RAG_EMBEDDING_MODEL,
-                    device=DEVICE_TYPE,
-                    trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-                )
-            )
+        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
 
         return {
             "status": True,
@@ -219,7 +232,6 @@ async def update_embedding_config(
                 "key": app.state.OPENAI_API_KEY,
             },
         }
-
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
         raise HTTPException(
@@ -242,13 +254,7 @@ async def update_reranking_config(
     try:
         app.state.RAG_RERANKING_MODEL = form_data.reranking_model
 
-        if app.state.RAG_RERANKING_MODEL == "":
-            app.state.sentence_transformer_rf = None
-        else:
-            app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
-                app.state.RAG_RERANKING_MODEL,
-                device=DEVICE_TYPE,
-            )
+        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
 
         return {
             "status": True,
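Taken together, the hunks above replace the module-level model loading in the RAG app with two reusable helpers: `update_embedding_model` and `update_reranking_model` run once at import time with the new `*_AUTO_UPDATE` flags, and the admin config endpoints call them again with `update_model=True` so a model change always refreshes the resolved snapshot. A minimal sketch of that lifecycle, with a plain namespace standing in for FastAPI's `app.state` and a stubbed `get_model_path` (everything except the helper names and the flags is a stand-in, not code from the commit):

# Hypothetical, trimmed-down sketch of the startup/update flow; only the
# helper names and the AUTO_UPDATE flag come from the diff above.
from types import SimpleNamespace

RAG_EMBEDDING_MODEL_AUTO_UPDATE = False  # mirrors the new config flag

state = SimpleNamespace(
    RAG_EMBEDDING_ENGINE="",  # "" selects the local sentence-transformers engine
    RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2",
    sentence_transformer_ef=None,
)


def get_model_path(model: str, update_model: bool = False) -> str:
    # Stand-in for the resolver added in the next file: it would return a
    # local snapshot path, optionally refreshing it when update_model is True.
    return model


def update_embedding_model(embedding_model: str, update_model: bool = False):
    # In the real module this builds a SentenceTransformer from the resolved
    # path; here we just record the path to keep the sketch dependency-free.
    if embedding_model and state.RAG_EMBEDDING_ENGINE == "":
        state.sentence_transformer_ef = get_model_path(embedding_model, update_model)
    else:
        state.sentence_transformer_ef = None


# Startup: honour the env-driven auto-update flag.
update_embedding_model(state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE)

# Admin config endpoint: force an update whenever the model is changed.
state.RAG_EMBEDDING_MODEL = "all-mpnet-base-v2"
update_embedding_model(state.RAG_EMBEDDING_MODEL, True)

The reranking path follows the same shape, except that an empty `reranking_model` simply clears `sentence_transformer_rf` instead of loading anything.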
@@ -1,3 +1,4 @@
+import os
 import logging
 import requests
 
@@ -8,6 +9,8 @@ from apps.ollama.main import (
     GenerateEmbeddingsForm,
 )
 
+from huggingface_hub import snapshot_download
+
 from langchain_core.documents import Document
 from langchain_community.retrievers import BM25Retriever
 from langchain.retrievers import (
@@ -282,8 +285,6 @@ def rag_messages(
 
         extracted_collections.extend(collection)
 
-    log.debug(f"relevant_contexts: {relevant_contexts}")
-
     context_string = ""
     for context in relevant_contexts:
         items = context["documents"][0]
@@ -319,6 +320,44 @@ def rag_messages(
     return messages
 
 
+def get_model_path(model: str, update_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+
+    local_files_only = not update_model
+
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"embedding_model: {model}")
+    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
+
+    # Inspiration from upstream sentence_transformers
+    if (
+        os.path.exists(model)
+        or ("\\" in model or model.count("/") > 1)
+        and local_files_only
+    ):
+        # If fully qualified path exists, return input, else set repo_id
+        return model
+    elif "/" not in model:
+        # Set valid repo_id for model short-name
+        model = "sentence-transformers" + "/" + model
+
+    snapshot_kwargs["repo_id"] = model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"model_repo_path: {model_repo_path}")
+        return model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine model snapshot path: {e}")
+        return model
+
+
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
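The new `get_model_path` resolves a model name to a local directory via `huggingface_hub.snapshot_download`: with `local_files_only=True` (the case when auto-update is off) it only consults the cache, while `False` may download or refresh the snapshot. Note that in the `if` above, Python's `and` binds tighter than `or`, so `local_files_only` only guards the "looks like a filesystem path" heuristic, not the `os.path.exists` check. A hedged, standalone sketch of the same resolution step (the model id below is an example, not a value from the commit):

# Standalone sketch: resolve a sentence-transformers model id to a cached
# snapshot path. Requires huggingface_hub; the model id is illustrative.
import os

from huggingface_hub import snapshot_download

model = "all-MiniLM-L6-v2"
if "/" not in model:
    # Short names are assumed to live under the sentence-transformers org.
    model = f"sentence-transformers/{model}"

try:
    # local_files_only=True reuses the local cache; False may hit the network.
    path = snapshot_download(
        repo_id=model,
        cache_dir=os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        local_files_only=True,
    )
    print(f"cached snapshot: {path}")
except Exception as exc:
    # Mirror the fallback in get_model_path: fall back to the raw name.
    print(f"no local snapshot for {model}: {exc}")

Because the cache directory comes from `SENTENCE_TRANSFORMERS_HOME`, pointing that variable at a persistent volume keeps resolved snapshots across restarts.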
@@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
 )
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
 
+RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
 RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
@@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
 if not RAG_RERANKING_MODEL == "":
     log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
 
+RAG_RERANKING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
 RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
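The two new flags are ordinary environment switches parsed in the config module: only the literal string `true` (any case) enables auto-update; unset or any other value leaves it off. A small sketch of the parsing rule (the example value is illustrative):

# Sketch of the flag parsing added above: only "true" (case-insensitive)
# enables auto-update; anything else, including unset, disables it.
import os

os.environ["RAG_EMBEDDING_MODEL_AUTO_UPDATE"] = "True"  # example value

RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_RERANKING_MODEL_AUTO_UPDATE = (
    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)

print(RAG_EMBEDDING_MODEL_AUTO_UPDATE)  # True
print(RAG_RERANKING_MODEL_AUTO_UPDATE)  # False here, since it was never set in this sketch

With both flags left off (the default), `get_model_path` is called with `local_files_only=True`, so startup prefers an already-cached snapshot rather than checking for updates.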
Steven Kreitzer