Merge pull request #1693 from buroa/buroa/hybrid-search

feat: hybrid search with reranking

commit 5ee2f1729a by Timothy Jaeryang Baek
8 changed files with 655 additions and 176 deletions
		|  | @ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. | |||
| The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), | ||||
| and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). | ||||
| 
 | ||||
| ## [0.1.122] - 2024-04-24 | ||||
| 
 | ||||
| - **🌟 Enhanced RAG Pipeline**: Added hybrid searching with `BM25`, reranking using `CrossEncoder`, and relevance score thresholds. | ||||
| 
 | ||||
| ## [0.1.121] - 2024-04-24 | ||||
| 
 | ||||
| ### Fixed | ||||
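For orientation, the hybrid pipeline announced in 0.1.122 combines a lexical retriever (BM25) with the existing embedding search, then rescores the merged candidates with a cross-encoder. Below is a minimal sketch of the idea, assuming the `rank_bm25` and `sentence-transformers` packages and an invented toy corpus; it is an illustration, not the PR's code (the PR wires the same pieces together through LangChain retrievers, shown further down).

```python
# Sketch of hybrid search with reranking; corpus and model choices are illustrative.
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder, util

corpus = [
    "Open WebUI is a web interface for LLMs",
    "Hybrid search merges BM25 with vector retrieval",
    "Reranking rescores candidates with a cross-encoder",
]
query = "how does hybrid retrieval with reranking work?"

# Lexical leg: BM25 over whitespace tokens.
bm25_scores = BM25Okapi([d.lower().split() for d in corpus]).get_scores(query.lower().split())

# Semantic leg: cosine similarity between sentence embeddings.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
cos_scores = util.cos_sim(embedder.encode(query), embedder.encode(corpus))[0].tolist()

# Merge both legs into one candidate pool. (The PR uses LangChain's EnsembleRetriever,
# which fuses ranks rather than raw scores; weighted score addition keeps this sketch short.)
candidates = sorted(
    range(len(corpus)),
    key=lambda i: 0.5 * bm25_scores[i] + 0.5 * cos_scores[i],
    reverse=True,
)

# Rerank the pooled candidates with a cross-encoder and keep the best hits.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
rerank_scores = reranker.predict([(query, corpus[i]) for i in candidates])
top = sorted(zip(candidates, rerank_scores), key=lambda p: p[1], reverse=True)
print([(corpus[i], float(s)) for i, s in top[:2]])
```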
Dockerfile (12 changes)
							|  | @ -8,8 +8,9 @@ ARG USE_CUDA_VER=cu121 | |||
| # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers | ||||
| # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard  | ||||
| # for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) | ||||
| # IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. | ||||
| # IMPORTANT: If you change the embedding model (default: sentence-transformers/all-MiniLM-L6-v2), you won't be able to use RAG Chat with documents previously loaded in the WebUI! You need to re-embed them. | ||||
| ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 | ||||
| ARG USE_RERANKING_MODEL="" | ||||
| 
 | ||||
| ######## WebUI frontend ######## | ||||
| FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build | ||||
|  | @ -30,6 +31,7 @@ ARG USE_CUDA | |||
| ARG USE_OLLAMA | ||||
| ARG USE_CUDA_VER | ||||
| ARG USE_EMBEDDING_MODEL | ||||
| ARG USE_RERANKING_MODEL | ||||
| 
 | ||||
| ## Basis ## | ||||
| ENV ENV=prod \ | ||||
|  | @ -38,7 +40,8 @@ ENV ENV=prod \ | |||
|     USE_OLLAMA_DOCKER=${USE_OLLAMA} \ | ||||
|     USE_CUDA_DOCKER=${USE_CUDA} \ | ||||
|     USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \ | ||||
|     USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} | ||||
|     USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \ | ||||
|     USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL} | ||||
| 
 | ||||
| ## Basis URL Config ## | ||||
| ENV OLLAMA_BASE_URL="/ollama" \ | ||||
|  | @ -62,8 +65,11 @@ ENV WHISPER_MODEL="base" \ | |||
| 
 | ||||
| ## RAG Embedding model settings ## | ||||
| ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \ | ||||
|     RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" \ | ||||
|     RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \ | ||||
|     SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" | ||||
| 
 | ||||
| ## Hugging Face download cache ## | ||||
| ENV HF_HOME="/app/backend/data/cache/embedding/models" | ||||
| #### Other models ########################################################## | ||||
| 
 | ||||
| WORKDIR /app/backend | ||||
|  | @ -39,8 +39,6 @@ import json | |||
| 
 | ||||
| import sentence_transformers | ||||
| 
 | ||||
| from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm | ||||
| 
 | ||||
| from apps.web.models.documents import ( | ||||
|     Documents, | ||||
|     DocumentForm, | ||||
|  | @ -48,9 +46,10 @@ from apps.web.models.documents import ( | |||
| ) | ||||
| 
 | ||||
| from apps.rag.utils import ( | ||||
|     get_model_path, | ||||
|     query_embeddings_doc, | ||||
|     query_embeddings_function, | ||||
|     query_embeddings_collection, | ||||
|     generate_openai_embeddings, | ||||
| ) | ||||
| 
 | ||||
| from utils.misc import ( | ||||
|  | @ -60,13 +59,20 @@ from utils.misc import ( | |||
|     extract_folders_after_data_docs, | ||||
| ) | ||||
| from utils.utils import get_current_user, get_admin_user | ||||
| 
 | ||||
| from config import ( | ||||
|     SRC_LOG_LEVELS, | ||||
|     UPLOAD_DIR, | ||||
|     DOCS_DIR, | ||||
|     RAG_TOP_K, | ||||
|     RAG_RELEVANCE_THRESHOLD, | ||||
|     RAG_EMBEDDING_ENGINE, | ||||
|     RAG_EMBEDDING_MODEL, | ||||
|     RAG_EMBEDDING_MODEL_AUTO_UPDATE, | ||||
|     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, | ||||
|     RAG_RERANKING_MODEL, | ||||
|     RAG_RERANKING_MODEL_AUTO_UPDATE, | ||||
|     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, | ||||
|     RAG_OPENAI_API_BASE_URL, | ||||
|     RAG_OPENAI_API_KEY, | ||||
|     DEVICE_TYPE, | ||||
|  | @ -83,14 +89,14 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) | |||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| 
 | ||||
| app.state.TOP_K = 4 | ||||
| app.state.TOP_K = RAG_TOP_K | ||||
| app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD | ||||
| app.state.CHUNK_SIZE = CHUNK_SIZE | ||||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||
| 
 | ||||
| 
 | ||||
| app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE | ||||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||
| app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL | ||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||
| 
 | ||||
| app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL | ||||
|  | @ -98,16 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY | |||
| 
 | ||||
| app.state.PDF_EXTRACT_IMAGES = False | ||||
| 
 | ||||
| if app.state.RAG_EMBEDDING_ENGINE == "": | ||||
| 
 | ||||
| def update_embedding_model( | ||||
|     embedding_model: str, | ||||
|     update_model: bool = False, | ||||
| ): | ||||
|     if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "": | ||||
|         app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( | ||||
|         app.state.RAG_EMBEDDING_MODEL, | ||||
|             get_model_path(embedding_model, update_model), | ||||
|             device=DEVICE_TYPE, | ||||
|             trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, | ||||
|         ) | ||||
|     else: | ||||
|         app.state.sentence_transformer_ef = None | ||||
| 
 | ||||
| 
 | ||||
| def update_reranking_model( | ||||
|     reranking_model: str, | ||||
|     update_model: bool = False, | ||||
| ): | ||||
|     if reranking_model: | ||||
|         app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder( | ||||
|             get_model_path(reranking_model, update_model), | ||||
|             device=DEVICE_TYPE, | ||||
|             trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, | ||||
|         ) | ||||
|     else: | ||||
|         app.state.sentence_transformer_rf = None | ||||
| 
 | ||||
| 
 | ||||
| update_embedding_model( | ||||
|     app.state.RAG_EMBEDDING_MODEL, | ||||
|     RAG_EMBEDDING_MODEL_AUTO_UPDATE, | ||||
| ) | ||||
| 
 | ||||
| update_reranking_model( | ||||
|     app.state.RAG_RERANKING_MODEL, | ||||
|     RAG_RERANKING_MODEL_AUTO_UPDATE, | ||||
| ) | ||||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| 
 | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|     allow_origins=origins, | ||||
|  | @ -134,6 +172,7 @@ async def get_status(): | |||
|         "template": app.state.RAG_TEMPLATE, | ||||
|         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, | ||||
|         "embedding_model": app.state.RAG_EMBEDDING_MODEL, | ||||
|         "reranking_model": app.state.RAG_RERANKING_MODEL, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|  | @ -150,6 +189,11 @@ async def get_embedding_config(user=Depends(get_admin_user)): | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/reranking") | ||||
| async def get_reranking_config(user=Depends(get_admin_user)): | ||||
|     return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL} | ||||
| 
 | ||||
| 
 | ||||
| class OpenAIConfigForm(BaseModel): | ||||
|     url: str | ||||
|     key: str | ||||
|  | @ -170,22 +214,14 @@ async def update_embedding_config( | |||
|     ) | ||||
|     try: | ||||
|         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine | ||||
|         app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model | ||||
| 
 | ||||
|         if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: | ||||
|             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model | ||||
|             app.state.sentence_transformer_ef = None | ||||
| 
 | ||||
|             if form_data.openai_config != None: | ||||
|                 app.state.OPENAI_API_BASE_URL = form_data.openai_config.url | ||||
|                 app.state.OPENAI_API_KEY = form_data.openai_config.key | ||||
|         else: | ||||
|             sentence_transformer_ef = sentence_transformers.SentenceTransformer( | ||||
|                 app.state.RAG_EMBEDDING_MODEL, | ||||
|                 device=DEVICE_TYPE, | ||||
|                 trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, | ||||
|             ) | ||||
|             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model | ||||
|             app.state.sentence_transformer_ef = sentence_transformer_ef | ||||
| 
 | ||||
|         update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) | ||||
| 
 | ||||
|         return { | ||||
|             "status": True, | ||||
|  | @ -196,7 +232,6 @@ async def update_embedding_config( | |||
|                 "key": app.state.OPENAI_API_KEY, | ||||
|             }, | ||||
|         } | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         log.exception(f"Problem updating embedding model: {e}") | ||||
|         raise HTTPException( | ||||
|  | @ -205,6 +240,34 @@ async def update_embedding_config( | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class RerankingModelUpdateForm(BaseModel): | ||||
|     reranking_model: str | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/reranking/update") | ||||
| async def update_reranking_config( | ||||
|     form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     log.info( | ||||
|         f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" | ||||
|     ) | ||||
|     try: | ||||
|         app.state.RAG_RERANKING_MODEL = form_data.reranking_model | ||||
| 
 | ||||
|         update_reranking_model(app.state.RAG_RERANKING_MODEL, True) | ||||
| 
 | ||||
|         return { | ||||
|             "status": True, | ||||
|             "reranking_model": app.state.RAG_RERANKING_MODEL, | ||||
|         } | ||||
|     except Exception as e: | ||||
|         log.exception(f"Problem updating reranking model: {e}") | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||||
|             detail=ERROR_MESSAGES.DEFAULT(e), | ||||
|         ) | ||||
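For reference, a hedged usage sketch of the new admin endpoint; the host, router prefix, token, and model name below are assumptions that depend on the deployment.

```python
# Hypothetical client call; adjust the base URL, prefix, and token to your deployment.
import requests

resp = requests.post(
    "http://localhost:8080/rag/api/v1/reranking/update",
    headers={"Authorization": "Bearer <admin-token>"},
    json={"reranking_model": "BAAI/bge-reranker-base"},  # any CrossEncoder-loadable model
)
print(resp.json())  # e.g. {'status': True, 'reranking_model': 'BAAI/bge-reranker-base'}
```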
| 
 | ||||
| 
 | ||||
| @app.get("/config") | ||||
| async def get_rag_config(user=Depends(get_admin_user)): | ||||
|     return { | ||||
|  | @ -257,11 +320,13 @@ async def get_query_settings(user=Depends(get_admin_user)): | |||
|         "status": True, | ||||
|         "template": app.state.RAG_TEMPLATE, | ||||
|         "k": app.state.TOP_K, | ||||
|         "r": app.state.RELEVANCE_THRESHOLD, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class QuerySettingsForm(BaseModel): | ||||
|     k: Optional[int] = None | ||||
|     r: Optional[float] = None | ||||
|     template: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
|  | @ -271,6 +336,7 @@ async def update_query_settings( | |||
| ): | ||||
|     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE | ||||
|     app.state.TOP_K = form_data.k if form_data.k else 4 | ||||
|     app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 | ||||
|     return {"status": True, "template": app.state.RAG_TEMPLATE} | ||||
| 
 | ||||
| 
 | ||||
|  | @ -278,6 +344,7 @@ class QueryDocForm(BaseModel): | |||
|     collection_name: str | ||||
|     query: str | ||||
|     k: Optional[int] = None | ||||
|     r: Optional[float] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/query/doc") | ||||
|  | @ -286,34 +353,22 @@ def query_doc_handler( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "": | ||||
|             query_embeddings = app.state.sentence_transformer_ef.encode( | ||||
|                 form_data.query | ||||
|             ).tolist() | ||||
|         elif app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             query_embeddings = generate_ollama_embeddings( | ||||
|                 GenerateEmbeddingsForm( | ||||
|                     **{ | ||||
|                         "model": app.state.RAG_EMBEDDING_MODEL, | ||||
|                         "prompt": form_data.query, | ||||
|                     } | ||||
|                 ) | ||||
|             ) | ||||
|         elif app.state.RAG_EMBEDDING_ENGINE == "openai": | ||||
|             query_embeddings = generate_openai_embeddings( | ||||
|                 model=app.state.RAG_EMBEDDING_MODEL, | ||||
|                 text=form_data.query, | ||||
|                 key=app.state.OPENAI_API_KEY, | ||||
|                 url=app.state.OPENAI_API_BASE_URL, | ||||
|         embeddings_function = query_embeddings_function( | ||||
|             app.state.RAG_EMBEDDING_ENGINE, | ||||
|             app.state.RAG_EMBEDDING_MODEL, | ||||
|             app.state.sentence_transformer_ef, | ||||
|             app.state.OPENAI_API_KEY, | ||||
|             app.state.OPENAI_API_BASE_URL, | ||||
|         ) | ||||
| 
 | ||||
|         return query_embeddings_doc( | ||||
|             collection_name=form_data.collection_name, | ||||
|             query=form_data.query, | ||||
|             query_embeddings=query_embeddings, | ||||
|             k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||
|             embeddings_function=embeddings_function, | ||||
|             reranking_function=app.state.sentence_transformer_rf, | ||||
|         ) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|  | @ -326,6 +381,7 @@ class QueryCollectionsForm(BaseModel): | |||
|     collection_names: List[str] | ||||
|     query: str | ||||
|     k: Optional[int] = None | ||||
|     r: Optional[float] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/query/collection") | ||||
|  | @ -334,33 +390,22 @@ def query_collection_handler( | |||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "": | ||||
|             query_embeddings = app.state.sentence_transformer_ef.encode( | ||||
|                 form_data.query | ||||
|             ).tolist() | ||||
|         elif app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             query_embeddings = generate_ollama_embeddings( | ||||
|                 GenerateEmbeddingsForm( | ||||
|                     **{ | ||||
|                         "model": app.state.RAG_EMBEDDING_MODEL, | ||||
|                         "prompt": form_data.query, | ||||
|                     } | ||||
|                 ) | ||||
|             ) | ||||
|         elif app.state.RAG_EMBEDDING_ENGINE == "openai": | ||||
|             query_embeddings = generate_openai_embeddings( | ||||
|                 model=app.state.RAG_EMBEDDING_MODEL, | ||||
|                 text=form_data.query, | ||||
|                 key=app.state.OPENAI_API_KEY, | ||||
|                 url=app.state.OPENAI_API_BASE_URL, | ||||
|         embeddings_function = query_embeddings_function( | ||||
|             app.state.RAG_EMBEDDING_ENGINE, | ||||
|             app.state.RAG_EMBEDDING_MODEL, | ||||
|             app.state.sentence_transformer_ef, | ||||
|             app.state.OPENAI_API_KEY, | ||||
|             app.state.OPENAI_API_BASE_URL, | ||||
|         ) | ||||
| 
 | ||||
|         return query_embeddings_collection( | ||||
|             collection_names=form_data.collection_names, | ||||
|             query_embeddings=query_embeddings, | ||||
|             query=form_data.query, | ||||
|             k=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||
|             embeddings_function=embeddings_function, | ||||
|             reranking_function=app.state.sentence_transformer_rf, | ||||
|         ) | ||||
| 
 | ||||
|     except Exception as e: | ||||
|         log.exception(e) | ||||
|         raise HTTPException( | ||||
|  | @ -427,8 +472,6 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b | |||
|     log.info(f"store_docs_in_vector_db {docs} {collection_name}") | ||||
| 
 | ||||
|     texts = [doc.page_content for doc in docs] | ||||
|     texts = list(map(lambda x: x.replace("\n", " "), texts)) | ||||
| 
 | ||||
|     metadatas = [doc.metadata for doc in docs] | ||||
| 
 | ||||
|     try: | ||||
|  | @ -440,27 +483,16 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b | |||
| 
 | ||||
|         collection = CHROMA_CLIENT.create_collection(name=collection_name) | ||||
| 
 | ||||
|         if app.state.RAG_EMBEDDING_ENGINE == "": | ||||
|             embeddings = app.state.sentence_transformer_ef.encode(texts).tolist() | ||||
|         elif app.state.RAG_EMBEDDING_ENGINE == "ollama": | ||||
|             embeddings = [ | ||||
|                 generate_ollama_embeddings( | ||||
|                     GenerateEmbeddingsForm( | ||||
|                         **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text} | ||||
|         embedding_func = query_embeddings_function( | ||||
|             app.state.RAG_EMBEDDING_ENGINE, | ||||
|             app.state.RAG_EMBEDDING_MODEL, | ||||
|             app.state.sentence_transformer_ef, | ||||
|             app.state.OPENAI_API_KEY, | ||||
|             app.state.OPENAI_API_BASE_URL, | ||||
|         ) | ||||
|                 ) | ||||
|                 for text in texts | ||||
|             ] | ||||
|         elif app.state.RAG_EMBEDDING_ENGINE == "openai": | ||||
|             embeddings = [ | ||||
|                 generate_openai_embeddings( | ||||
|                     model=app.state.RAG_EMBEDDING_MODEL, | ||||
|                     text=text, | ||||
|                     key=app.state.OPENAI_API_KEY, | ||||
|                     url=app.state.OPENAI_API_BASE_URL, | ||||
|                 ) | ||||
|                 for text in texts | ||||
|             ] | ||||
| 
 | ||||
|         embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) | ||||
|         embeddings = embedding_func(embedding_texts) | ||||
| 
 | ||||
|         for batch in create_batches( | ||||
|             api=CHROMA_CLIENT, | ||||
|  | @ -1,3 +1,4 @@ | |||
| import os | ||||
| import logging | ||||
| import requests | ||||
| 
 | ||||
|  | @ -8,6 +9,15 @@ from apps.ollama.main import ( | |||
|     GenerateEmbeddingsForm, | ||||
| ) | ||||
| 
 | ||||
| from huggingface_hub import snapshot_download | ||||
| 
 | ||||
| from langchain_core.documents import Document | ||||
| from langchain_community.retrievers import BM25Retriever | ||||
| from langchain.retrievers import ( | ||||
|     ContextualCompressionRetriever, | ||||
|     EnsembleRetriever, | ||||
| ) | ||||
| 
 | ||||
| from config import SRC_LOG_LEVELS, CHROMA_CLIENT | ||||
| 
 | ||||
| 
 | ||||
|  | @ -15,18 +25,53 @@ log = logging.getLogger(__name__) | |||
| log.setLevel(SRC_LOG_LEVELS["RAG"]) | ||||
| 
 | ||||
| 
 | ||||
| def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int): | ||||
| def query_embeddings_doc( | ||||
|     collection_name: str, | ||||
|     query: str, | ||||
|     k: int, | ||||
|     r: float, | ||||
|     embeddings_function, | ||||
|     reranking_function, | ||||
| ): | ||||
|     try: | ||||
|         # when running in Docker, the model is taken from the environment variable | ||||
|         log.info(f"query_embeddings_doc {query}") | ||||
|         collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||
| 
 | ||||
|         result = collection.query( | ||||
|             query_embeddings=[query_embeddings], | ||||
|             n_results=k, | ||||
|         documents = collection.get()  # get all documents | ||||
|         bm25_retriever = BM25Retriever.from_texts( | ||||
|             texts=documents.get("documents"), | ||||
|             metadatas=documents.get("metadatas"), | ||||
|         ) | ||||
|         bm25_retriever.k = k | ||||
| 
 | ||||
|         chroma_retriever = ChromaRetriever( | ||||
|             collection=collection, | ||||
|             embeddings_function=embeddings_function, | ||||
|             top_n=k, | ||||
|         ) | ||||
| 
 | ||||
|         log.info(f"query_embeddings_doc:result {result}") | ||||
|         ensemble_retriever = EnsembleRetriever( | ||||
|             retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5] | ||||
|         ) | ||||
| 
 | ||||
|         compressor = RerankCompressor( | ||||
|             embeddings_function=embeddings_function, | ||||
|             reranking_function=reranking_function, | ||||
|             r_score=r, | ||||
|             top_n=k, | ||||
|         ) | ||||
| 
 | ||||
|         compression_retriever = ContextualCompressionRetriever( | ||||
|             base_compressor=compressor, base_retriever=ensemble_retriever | ||||
|         ) | ||||
| 
 | ||||
|         result = compression_retriever.invoke(query) | ||||
|         result = { | ||||
|             "distances": [[d.metadata.get("score") for d in result]], | ||||
|             "documents": [[d.page_content for d in result]], | ||||
|             "metadatas": [[d.metadata for d in result]], | ||||
|         } | ||||
| 
 | ||||
|         return result | ||||
|     except Exception as e: | ||||
|         raise e | ||||
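The `EnsembleRetriever` above merges the BM25 and Chroma result lists with weighted Reciprocal Rank Fusion, which combines ranks rather than raw scores. A minimal sketch of that fusion step on invented document ids, using the conventional RRF constant c = 60:

```python
# Weighted Reciprocal Rank Fusion: each list contributes weight / (c + rank) per document.
def weighted_rrf(result_lists, weights, c=60):
    scores = {}
    for docs, weight in zip(result_lists, weights):
        for rank, doc_id in enumerate(docs, start=1):
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (c + rank)
    return sorted(scores, key=scores.get, reverse=True)

bm25_hits = ["d3", "d1", "d7"]    # lexical ordering
vector_hits = ["d1", "d5", "d3"]  # semantic ordering
print(weighted_rrf([bm25_hits, vector_hits], [0.5, 0.5]))
# ['d1', 'd3', 'd5', 'd7']: d1 and d3 rise because both retrievers agree on them.
```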
|  | @ -34,63 +79,65 @@ def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: | |||
| 
 | ||||
| def merge_and_sort_query_results(query_results, k): | ||||
|     # Initialize lists to store combined data | ||||
|     combined_ids = [] | ||||
|     combined_distances = [] | ||||
|     combined_metadatas = [] | ||||
|     combined_documents = [] | ||||
|     combined_metadatas = [] | ||||
| 
 | ||||
|     # Combine data from each dictionary | ||||
|     for data in query_results: | ||||
|         combined_ids.extend(data["ids"][0]) | ||||
|         combined_distances.extend(data["distances"][0]) | ||||
|         combined_metadatas.extend(data["metadatas"][0]) | ||||
|         combined_documents.extend(data["documents"][0]) | ||||
|         combined_metadatas.extend(data["metadatas"][0]) | ||||
| 
 | ||||
|     # Create a list of tuples (distance, id, metadata, document) | ||||
|     combined = list( | ||||
|         zip(combined_distances, combined_ids, combined_metadatas, combined_documents) | ||||
|     ) | ||||
|     # Create a list of tuples (distance, document, metadata) | ||||
|     combined = list(zip(combined_distances, combined_documents, combined_metadatas)) | ||||
| 
 | ||||
|     # Sort the list based on distances | ||||
|     combined.sort(key=lambda x: x[0]) | ||||
| 
 | ||||
|     # We don't have anything :-( | ||||
|     if not combined: | ||||
|         sorted_distances = [] | ||||
|         sorted_documents = [] | ||||
|         sorted_metadatas = [] | ||||
|     else: | ||||
|         # Unzip the sorted list | ||||
|     sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) | ||||
|         sorted_distances, sorted_documents, sorted_metadatas = zip(*combined) | ||||
| 
 | ||||
|         # Slicing the lists to include only k elements | ||||
|         sorted_distances = list(sorted_distances)[:k] | ||||
|     sorted_ids = list(sorted_ids)[:k] | ||||
|     sorted_metadatas = list(sorted_metadatas)[:k] | ||||
|         sorted_documents = list(sorted_documents)[:k] | ||||
|         sorted_metadatas = list(sorted_metadatas)[:k] | ||||
| 
 | ||||
|     # Create the output dictionary | ||||
|     merged_query_results = { | ||||
|         "ids": [sorted_ids], | ||||
|     result = { | ||||
|         "distances": [sorted_distances], | ||||
|         "metadatas": [sorted_metadatas], | ||||
|         "documents": [sorted_documents], | ||||
|         "embeddings": None, | ||||
|         "uris": None, | ||||
|         "data": None, | ||||
|         "metadatas": [sorted_metadatas], | ||||
|     } | ||||
| 
 | ||||
|     return merged_query_results | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def query_embeddings_collection( | ||||
|     collection_names: List[str], query: str, query_embeddings, k: int | ||||
|     collection_names: List[str], | ||||
|     query: str, | ||||
|     k: int, | ||||
|     r: float, | ||||
|     embeddings_function, | ||||
|     reranking_function, | ||||
| ): | ||||
| 
 | ||||
|     results = [] | ||||
|     log.info(f"query_embeddings_collection {query}") | ||||
| 
 | ||||
|     for collection_name in collection_names: | ||||
|         try: | ||||
|             result = query_embeddings_doc( | ||||
|                 collection_name=collection_name, | ||||
|                 query=query, | ||||
|                 query_embeddings=query_embeddings, | ||||
|                 k=k, | ||||
|                 r=r, | ||||
|                 embeddings_function=embeddings_function, | ||||
|                 reranking_function=reranking_function, | ||||
|             ) | ||||
|             results.append(result) | ||||
|         except: | ||||
|  | @ -105,19 +152,57 @@ def rag_template(template: str, context: str, query: str): | |||
|     return template | ||||
| 
 | ||||
| 
 | ||||
| def rag_messages( | ||||
|     docs, | ||||
|     messages, | ||||
|     template, | ||||
|     k, | ||||
| def query_embeddings_function( | ||||
|     embedding_engine, | ||||
|     embedding_model, | ||||
|     embedding_function, | ||||
|     openai_key, | ||||
|     openai_url, | ||||
| ): | ||||
|     if embedding_engine == "": | ||||
|         return lambda query: embedding_function.encode(query).tolist() | ||||
|     elif embedding_engine in ["ollama", "openai"]: | ||||
|         if embedding_engine == "ollama": | ||||
|             func = lambda query: generate_ollama_embeddings( | ||||
|                 GenerateEmbeddingsForm( | ||||
|                     **{ | ||||
|                         "model": embedding_model, | ||||
|                         "prompt": query, | ||||
|                     } | ||||
|                 ) | ||||
|             ) | ||||
|         elif embedding_engine == "openai": | ||||
|             func = lambda query: generate_openai_embeddings( | ||||
|                 model=embedding_model, | ||||
|                 text=query, | ||||
|                 key=openai_key, | ||||
|                 url=openai_url, | ||||
|             ) | ||||
| 
 | ||||
|         def generate_multiple(query, f): | ||||
|             if isinstance(query, list): | ||||
|                 return [f(q) for q in query] | ||||
|             else: | ||||
|                 return f(query) | ||||
| 
 | ||||
|         return lambda query: generate_multiple(query, func) | ||||
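The factory above returns a single callable that the rest of the pipeline uses for both one-off queries and document batches; a small hedged usage sketch follows (the model object is illustrative, and the definitions above are assumed to be in scope):

```python
# Hypothetical usage of query_embeddings_function with the local sentence-transformers engine.
import sentence_transformers

st_model = sentence_transformers.SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embed = query_embeddings_function("", "all-MiniLM-L6-v2", st_model, None, None)

one_vector = embed("what is hybrid search?")           # encode() on a string: one vector
many_vectors = embed(["first chunk", "second chunk"])  # encode() on a list: one vector per item
# For the "ollama"/"openai" engines, generate_multiple() maps the per-query call over lists instead.
```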
| 
 | ||||
| 
 | ||||
| def rag_messages( | ||||
|     docs, | ||||
|     messages, | ||||
|     template, | ||||
|     k, | ||||
|     r, | ||||
|     embedding_engine, | ||||
|     embedding_model, | ||||
|     embedding_function, | ||||
|     reranking_function, | ||||
|     openai_key, | ||||
|     openai_url, | ||||
| ): | ||||
|     log.debug( | ||||
|         f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}" | ||||
|         f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}" | ||||
|     ) | ||||
| 
 | ||||
|     last_user_message_idx = None | ||||
|  | @ -145,62 +230,66 @@ def rag_messages( | |||
|         content_type = None | ||||
|         query = "" | ||||
| 
 | ||||
|     embeddings_function = query_embeddings_function( | ||||
|         embedding_engine, | ||||
|         embedding_model, | ||||
|         embedding_function, | ||||
|         openai_key, | ||||
|         openai_url, | ||||
|     ) | ||||
| 
 | ||||
|     extracted_collections = [] | ||||
|     relevant_contexts = [] | ||||
| 
 | ||||
|     for doc in docs: | ||||
|         context = None | ||||
| 
 | ||||
|         try: | ||||
|         collection = doc.get("collection_name") | ||||
|         if collection: | ||||
|             collection = [collection] | ||||
|         else: | ||||
|             collection = doc.get("collection_names", []) | ||||
| 
 | ||||
|         collection = set(collection).difference(extracted_collections) | ||||
|         if not collection: | ||||
|             log.debug(f"skipping {doc} as it has already been extracted") | ||||
|             continue | ||||
| 
 | ||||
|         try: | ||||
|             if doc["type"] == "text": | ||||
|                 context = doc["content"] | ||||
|             else: | ||||
|                 if embedding_engine == "": | ||||
|                     query_embeddings = embedding_function.encode(query).tolist() | ||||
|                 elif embedding_engine == "ollama": | ||||
|                     query_embeddings = generate_ollama_embeddings( | ||||
|                         GenerateEmbeddingsForm( | ||||
|                             **{ | ||||
|                                 "model": embedding_model, | ||||
|                                 "prompt": query, | ||||
|                             } | ||||
|                         ) | ||||
|                     ) | ||||
|                 elif embedding_engine == "openai": | ||||
|                     query_embeddings = generate_openai_embeddings( | ||||
|                         model=embedding_model, | ||||
|                         text=query, | ||||
|                         key=openai_key, | ||||
|                         url=openai_url, | ||||
|                     ) | ||||
| 
 | ||||
|                 if doc["type"] == "collection": | ||||
|             elif doc["type"] == "collection": | ||||
|                 context = query_embeddings_collection( | ||||
|                     collection_names=doc["collection_names"], | ||||
|                     query=query, | ||||
|                         query_embeddings=query_embeddings, | ||||
|                     k=k, | ||||
|                     r=r, | ||||
|                     embeddings_function=embeddings_function, | ||||
|                     reranking_function=reranking_function, | ||||
|                 ) | ||||
|             else: | ||||
|                 context = query_embeddings_doc( | ||||
|                     collection_name=doc["collection_name"], | ||||
|                     query=query, | ||||
|                         query_embeddings=query_embeddings, | ||||
|                     k=k, | ||||
|                     r=r, | ||||
|                     embeddings_function=embeddings_function, | ||||
|                     reranking_function=reranking_function, | ||||
|                 ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             log.exception(e) | ||||
|             context = None | ||||
| 
 | ||||
|         if context: | ||||
|             relevant_contexts.append(context) | ||||
| 
 | ||||
|     log.debug(f"relevant_contexts: {relevant_contexts}") | ||||
|         extracted_collections.extend(collection) | ||||
| 
 | ||||
|     context_string = "" | ||||
|     for context in relevant_contexts: | ||||
|         if context: | ||||
|             context_string += " ".join(context["documents"][0]) + "\n" | ||||
|         items = context["documents"][0] | ||||
|         context_string += "\n\n".join(items) | ||||
|     context_string = context_string.strip() | ||||
| 
 | ||||
|     ra_content = rag_template( | ||||
|         template=template, | ||||
|  | @ -208,6 +297,8 @@ def rag_messages( | |||
|         query=query, | ||||
|     ) | ||||
| 
 | ||||
|     log.debug(f"ra_content: {ra_content}") | ||||
| 
 | ||||
|     if content_type == "list": | ||||
|         new_content = [] | ||||
|         for content_item in user_message["content"]: | ||||
|  | @ -229,6 +320,44 @@ def rag_messages( | |||
|     return messages | ||||
| 
 | ||||
| 
 | ||||
| def get_model_path(model: str, update_model: bool = False): | ||||
|     # Construct huggingface_hub kwargs with local_files_only to return the snapshot path | ||||
|     cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME") | ||||
| 
 | ||||
|     local_files_only = not update_model | ||||
| 
 | ||||
|     snapshot_kwargs = { | ||||
|         "cache_dir": cache_dir, | ||||
|         "local_files_only": local_files_only, | ||||
|     } | ||||
| 
 | ||||
|     log.debug(f"model: {model}") | ||||
|     log.debug(f"snapshot_kwargs: {snapshot_kwargs}") | ||||
| 
 | ||||
|     # Inspiration from upstream sentence_transformers | ||||
|     if ( | ||||
|         os.path.exists(model) | ||||
|         or ("\\" in model or model.count("/") > 1) | ||||
|         and local_files_only | ||||
|     ): | ||||
|         # If fully qualified path exists, return input, else set repo_id | ||||
|         return model | ||||
|     elif "/" not in model: | ||||
|         # Set valid repo_id for model short-name | ||||
|         model = f"sentence-transformers/{model}" | ||||
| 
 | ||||
|     snapshot_kwargs["repo_id"] = model | ||||
| 
 | ||||
|     # Attempt to query the huggingface_hub library to determine the local path and/or to update | ||||
|     try: | ||||
|         model_repo_path = snapshot_download(**snapshot_kwargs) | ||||
|         log.debug(f"model_repo_path: {model_repo_path}") | ||||
|         return model_repo_path | ||||
|     except Exception as e: | ||||
|         log.exception(f"Cannot determine model snapshot path: {e}") | ||||
|         return model | ||||
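`get_model_path` accepts an existing local path, a bare short-name, or a full repo id; a few illustrative calls follow (paths and names are examples, and actual return values depend on the local Hugging Face cache state):

```python
# Illustrative inputs; assumes get_model_path from above is in scope.
get_model_path("all-MiniLM-L6-v2")              # short name, expanded to "sentence-transformers/all-MiniLM-L6-v2"
get_model_path("BAAI/bge-reranker-base", True)  # full repo id; update_model=True permits a network refresh
get_model_path("/models/my-local-encoder")      # an existing local path is returned unchanged
```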
| 
 | ||||
| 
 | ||||
| def generate_openai_embeddings( | ||||
|     model: str, text: str, key: str, url: str = "https://api.openai.com/v1" | ||||
| ): | ||||
|  | @ -250,3 +379,97 @@ def generate_openai_embeddings( | |||
|     except Exception as e: | ||||
|         print(e) | ||||
|         return None | ||||
| 
 | ||||
| 
 | ||||
| from typing import Any | ||||
| 
 | ||||
| from langchain_core.retrievers import BaseRetriever | ||||
| from langchain_core.callbacks import CallbackManagerForRetrieverRun | ||||
| 
 | ||||
| 
 | ||||
| class ChromaRetriever(BaseRetriever): | ||||
|     collection: Any | ||||
|     embeddings_function: Any | ||||
|     top_n: int | ||||
| 
 | ||||
|     def _get_relevant_documents( | ||||
|         self, | ||||
|         query: str, | ||||
|         *, | ||||
|         run_manager: CallbackManagerForRetrieverRun, | ||||
|     ) -> List[Document]: | ||||
|         query_embeddings = self.embeddings_function(query) | ||||
| 
 | ||||
|         results = self.collection.query( | ||||
|             query_embeddings=[query_embeddings], | ||||
|             n_results=self.top_n, | ||||
|         ) | ||||
| 
 | ||||
|         ids = results["ids"][0] | ||||
|         metadatas = results["metadatas"][0] | ||||
|         documents = results["documents"][0] | ||||
| 
 | ||||
|         return [ | ||||
|             Document( | ||||
|                 metadata=metadatas[idx], | ||||
|                 page_content=documents[idx], | ||||
|             ) | ||||
|             for idx in range(len(ids)) | ||||
|         ] | ||||
| 
 | ||||
| 
 | ||||
| import operator | ||||
| 
 | ||||
| from typing import Optional, Sequence | ||||
| 
 | ||||
| from langchain_core.documents import BaseDocumentCompressor, Document | ||||
| from langchain_core.callbacks import Callbacks | ||||
| from langchain_core.pydantic_v1 import Extra | ||||
| 
 | ||||
| from sentence_transformers import util | ||||
| 
 | ||||
| 
 | ||||
| class RerankCompressor(BaseDocumentCompressor): | ||||
|     embeddings_function: Any | ||||
|     reranking_function: Any | ||||
|     r_score: float | ||||
|     top_n: int | ||||
| 
 | ||||
|     class Config: | ||||
|         extra = Extra.forbid | ||||
|         arbitrary_types_allowed = True | ||||
| 
 | ||||
|     def compress_documents( | ||||
|         self, | ||||
|         documents: Sequence[Document], | ||||
|         query: str, | ||||
|         callbacks: Optional[Callbacks] = None, | ||||
|     ) -> Sequence[Document]: | ||||
|         if self.reranking_function: | ||||
|             scores = self.reranking_function.predict( | ||||
|                 [(query, doc.page_content) for doc in documents] | ||||
|             ) | ||||
|         else: | ||||
|             query_embedding = self.embeddings_function(query) | ||||
|             document_embedding = self.embeddings_function( | ||||
|                 [doc.page_content for doc in documents] | ||||
|             ) | ||||
|             scores = util.cos_sim(query_embedding, document_embedding)[0] | ||||
| 
 | ||||
|         docs_with_scores = list(zip(documents, scores.tolist())) | ||||
|         if self.r_score: | ||||
|             docs_with_scores = [ | ||||
|                 (d, s) for d, s in docs_with_scores if s >= self.r_score | ||||
|             ] | ||||
| 
 | ||||
|         result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True) | ||||
|         final_results = [] | ||||
|         for doc, doc_score in result[: self.top_n]: | ||||
|             metadata = doc.metadata | ||||
|             metadata["score"] = doc_score | ||||
|             doc = Document( | ||||
|                 page_content=doc.page_content, | ||||
|                 metadata=metadata, | ||||
|             ) | ||||
|             final_results.append(doc) | ||||
|         return final_results | ||||
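One caveat: the relevance threshold `r` is compared against whatever scale the active scoring path produces. Without a reranking model the fallback is cosine similarity in [-1, 1], while CrossEncoder outputs depend on the model and are often unbounded logits. A toy sketch of the threshold-and-top-n step, with invented scores:

```python
# Invented scores; mirrors the filter/sort/slice logic in compress_documents above.
docs_with_scores = [("doc-a", 0.91), ("doc-b", 0.42), ("doc-c", 0.17)]
r_score, top_n = 0.3, 2

kept = [(d, s) for d, s in docs_with_scores if s >= r_score]  # doc-c falls below the threshold
kept.sort(key=lambda pair: pair[1], reverse=True)
print(kept[:top_n])  # [('doc-a', 0.91), ('doc-b', 0.42)]
```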
|  | @ -420,6 +420,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": | |||
| CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" | ||||
| # this uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) | ||||
| 
 | ||||
| RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) | ||||
| RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) | ||||
| 
 | ||||
| RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") | ||||
| 
 | ||||
| RAG_EMBEDDING_MODEL = os.environ.get( | ||||
|  | @ -427,10 +430,26 @@ RAG_EMBEDDING_MODEL = os.environ.get( | |||
| ) | ||||
| log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}") | ||||
| 
 | ||||
| RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( | ||||
|     os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" | ||||
| ) | ||||
| 
 | ||||
| RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( | ||||
|     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" | ||||
| ) | ||||
| 
 | ||||
| RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "") | ||||
| if RAG_RERANKING_MODEL != "": | ||||
|     log.info(f"Reranking model set: {RAG_RERANKING_MODEL}") | ||||
| 
 | ||||
| RAG_RERANKING_MODEL_AUTO_UPDATE = ( | ||||
|     os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true" | ||||
| ) | ||||
| 
 | ||||
| RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( | ||||
|     os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" | ||||
| ) | ||||
| 
 | ||||
| # device type for embedding models: "cpu" (default), "cuda" (NVIDIA GPU required), or "mps" (Apple Silicon); choosing the right one can improve performance | ||||
| USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") | ||||
| 
 | ||||
|  | @ -439,16 +458,15 @@ if USE_CUDA.lower() == "true": | |||
| else: | ||||
|     DEVICE_TYPE = "cpu" | ||||
| 
 | ||||
| 
 | ||||
| CHROMA_CLIENT = chromadb.PersistentClient( | ||||
|     path=CHROMA_DATA_PATH, | ||||
|     settings=Settings(allow_reset=True, anonymized_telemetry=False), | ||||
| ) | ||||
| CHUNK_SIZE = 1500 | ||||
| CHUNK_OVERLAP = 100 | ||||
| 
 | ||||
| CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500")) | ||||
| CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100")) | ||||
| 
 | ||||
| RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags. | ||||
| DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags. | ||||
| <context> | ||||
|     [context] | ||||
| </context> | ||||
|  | @ -462,6 +480,8 @@ And answer according to the language of the user's question. | |||
| Given the context information, answer the query. | ||||
| Query: [query]""" | ||||
| 
 | ||||
| RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) | ||||
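The `[context]` and `[query]` tokens in the template are placeholders that `rag_template()` substitutes at query time, presumably via plain string replacement; a hedged sketch under that assumption:

```python
# Assumed behavior of the template substitution; fill_template is a hypothetical stand-in.
def fill_template(template: str, context: str, query: str) -> str:
    return template.replace("[context]", context).replace("[query]", query)

print(fill_template("<context>[context]</context>\nQuery: [query]",
                    "retrieved chunks go here", "What changed in 0.1.122?"))
```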
| 
 | ||||
| RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) | ||||
| RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) | ||||
| 
 | ||||
|  | @ -120,9 +120,11 @@ class RAGMiddleware(BaseHTTPMiddleware): | |||
|                     data["messages"], | ||||
|                     rag_app.state.RAG_TEMPLATE, | ||||
|                     rag_app.state.TOP_K, | ||||
|                     rag_app.state.RELEVANCE_THRESHOLD, | ||||
|                     rag_app.state.RAG_EMBEDDING_ENGINE, | ||||
|                     rag_app.state.RAG_EMBEDDING_MODEL, | ||||
|                     rag_app.state.sentence_transformer_ef, | ||||
|                     rag_app.state.sentence_transformer_rf, | ||||
|                     rag_app.state.OPENAI_API_KEY, | ||||
|                     rag_app.state.OPENAI_API_BASE_URL, | ||||
|                 ) | ||||
|  | @ -123,6 +123,7 @@ export const getQuerySettings = async (token: string) => { | |||
| 
 | ||||
| type QuerySettings = { | ||||
| 	k: number | null; | ||||
| 	r: number | null; | ||||
| 	template: string | null; | ||||
| }; | ||||
| 
 | ||||
|  | @ -413,3 +414,64 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod | |||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const getRerankingConfig = async (token: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${RAG_API_BASE_URL}/reranking`, { | ||||
| 		method: 'GET', | ||||
| 		headers: { | ||||
| 			'Content-Type': 'application/json', | ||||
| 			Authorization: `Bearer ${token}` | ||||
| 		} | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
| 		}) | ||||
| 		.catch((err) => { | ||||
| 			console.log(err); | ||||
| 			error = err.detail; | ||||
| 			return null; | ||||
| 		}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| type RerankingModelUpdateForm = { | ||||
| 	reranking_model: string; | ||||
| }; | ||||
| 
 | ||||
| export const updateRerankingConfig = async (token: string, payload: RerankingModelUpdateForm) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${RAG_API_BASE_URL}/reranking/update`, { | ||||
| 		method: 'POST', | ||||
| 		headers: { | ||||
| 			'Content-Type': 'application/json', | ||||
| 			Authorization: `Bearer ${token}` | ||||
| 		}, | ||||
| 		body: JSON.stringify({ | ||||
| 			...payload | ||||
| 		}) | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
| 		}) | ||||
| 		.catch((err) => { | ||||
| 			console.log(err); | ||||
| 			error = err.detail; | ||||
| 			return null; | ||||
| 		}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
|  | @ -8,7 +8,9 @@ | |||
| 		updateQuerySettings, | ||||
| 		resetVectorDB, | ||||
| 		getEmbeddingConfig, | ||||
| 		updateEmbeddingConfig | ||||
| 		updateEmbeddingConfig, | ||||
| 		getRerankingConfig, | ||||
| 		updateRerankingConfig | ||||
| 	} from '$lib/apis/rag'; | ||||
| 
 | ||||
| 	import { documents, models } from '$lib/stores'; | ||||
|  | @ -23,11 +25,13 @@ | |||
| 
 | ||||
| 	let scanDirLoading = false; | ||||
| 	let updateEmbeddingModelLoading = false; | ||||
| 	let updateRerankingModelLoading = false; | ||||
| 
 | ||||
| 	let showResetConfirm = false; | ||||
| 
 | ||||
| 	let embeddingEngine = ''; | ||||
| 	let embeddingModel = ''; | ||||
| 	let rerankingModel = ''; | ||||
| 
 | ||||
| 	let OpenAIKey = ''; | ||||
| 	let OpenAIUrl = ''; | ||||
|  | @ -38,6 +42,7 @@ | |||
| 
 | ||||
| 	let querySettings = { | ||||
| 		template: '', | ||||
| 		r: 0.0, | ||||
| 		k: 4 | ||||
| 	}; | ||||
| 
 | ||||
|  | @ -115,6 +120,29 @@ | |||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
| 	const rerankingModelUpdateHandler = async () => { | ||||
| 		console.log('Update reranking model attempt:', rerankingModel); | ||||
| 
 | ||||
| 		updateRerankingModelLoading = true; | ||||
| 		const res = await updateRerankingConfig(localStorage.token, { | ||||
| 			reranking_model: rerankingModel | ||||
| 		}).catch(async (error) => { | ||||
| 			toast.error(error); | ||||
| 			await setRerankingConfig(); | ||||
| 			return null; | ||||
| 		}); | ||||
| 		updateRerankingModelLoading = false; | ||||
| 
 | ||||
| 		if (res) { | ||||
| 			console.log('rerankingModelUpdateHandler:', res); | ||||
| 			if (res.status === true) { | ||||
| 				toast.success($i18n.t('Reranking model set to "{{reranking_model}}"', res), { | ||||
| 					duration: 1000 * 10 | ||||
| 				}); | ||||
| 			} | ||||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
| 	const submitHandler = async () => { | ||||
| 		const res = await updateRAGConfig(localStorage.token, { | ||||
| 			pdf_extract_images: pdfExtractImages, | ||||
|  | @ -138,6 +166,14 @@ | |||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
| 	const setRerankingConfig = async () => { | ||||
| 		const rerankingConfig = await getRerankingConfig(localStorage.token); | ||||
| 
 | ||||
| 		if (rerankingConfig) { | ||||
| 			rerankingModel = rerankingConfig.reranking_model; | ||||
| 		} | ||||
| 	}; | ||||
| 
 | ||||
| 	onMount(async () => { | ||||
| 		const res = await getRAGConfig(localStorage.token); | ||||
| 
 | ||||
|  | @ -149,6 +185,7 @@ | |||
| 		} | ||||
| 
 | ||||
| 		await setEmbeddingConfig(); | ||||
| 		await setRerankingConfig(); | ||||
| 
 | ||||
| 		querySettings = await getQuerySettings(localStorage.token); | ||||
| 	}); | ||||
|  | @ -349,6 +386,79 @@ | |||
| 
 | ||||
| 				<hr class=" dark:border-gray-700 my-3" /> | ||||
| 
 | ||||
| 				<div class=" "> | ||||
| 					<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Reranking Model')}</div> | ||||
| 
 | ||||
| 					<div class="flex w-full"> | ||||
| 						<div class="flex-1 mr-2"> | ||||
| 							<input | ||||
| 								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" | ||||
| 								placeholder={$i18n.t('Update reranking model (e.g. {{model}})', { | ||||
| 									model: rerankingModel.slice(-40) | ||||
| 								})} | ||||
| 								bind:value={rerankingModel} | ||||
| 							/> | ||||
| 						</div> | ||||
| 						<button | ||||
| 							class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition" | ||||
| 							on:click={() => { | ||||
| 								rerankingModelUpdateHandler(); | ||||
| 							}} | ||||
| 							disabled={updateRerankingModelLoading} | ||||
| 						> | ||||
| 							{#if updateRerankingModelLoading} | ||||
| 								<div class="self-center"> | ||||
| 									<svg | ||||
| 										class=" w-4 h-4" | ||||
| 										viewBox="0 0 24 24" | ||||
| 										fill="currentColor" | ||||
| 										xmlns="http://www.w3.org/2000/svg" | ||||
| 										><style> | ||||
| 											.spinner_ajPY { | ||||
| 												transform-origin: center; | ||||
| 												animation: spinner_AtaB 0.75s infinite linear; | ||||
| 											} | ||||
| 											@keyframes spinner_AtaB { | ||||
| 												100% { | ||||
| 													transform: rotate(360deg); | ||||
| 												} | ||||
| 											} | ||||
| 										</style><path | ||||
| 											d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z" | ||||
| 											opacity=".25" | ||||
| 										/><path | ||||
| 											d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z" | ||||
| 											class="spinner_ajPY" | ||||
| 										/></svg | ||||
| 									> | ||||
| 								</div> | ||||
| 							{:else} | ||||
| 								<svg | ||||
| 									xmlns="http://www.w3.org/2000/svg" | ||||
| 									viewBox="0 0 16 16" | ||||
| 									fill="currentColor" | ||||
| 									class="w-4 h-4" | ||||
| 								> | ||||
| 									<path | ||||
| 										d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z" | ||||
| 									/> | ||||
| 									<path | ||||
| 										d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z" | ||||
| 									/> | ||||
| 								</svg> | ||||
| 							{/if} | ||||
| 						</button> | ||||
| 					</div> | ||||
| 				</div> | ||||
| 
 | ||||
| 				<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500"> | ||||
| 					{$i18n.t( | ||||
| 						'Note: If you choose a reranking model, it will use that to score and rerank instead of the embedding model.' | ||||
| 					)} | ||||
| 				</div> | ||||
| 
 | ||||
| 				<hr class=" dark:border-gray-700 my-3" /> | ||||
| 
 | ||||
| 				<div class="  flex w-full justify-between"> | ||||
| 					<div class=" self-center text-xs font-medium"> | ||||
| 						{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })} | ||||
|  | @ -473,6 +583,26 @@ | |||
| 						</div> | ||||
| 					</div> | ||||
| 
 | ||||
| 					<div class=" flex"> | ||||
| 						<div class="  flex w-full justify-between"> | ||||
| 							<div class="self-center text-xs font-medium flex-1"> | ||||
| 								{$i18n.t('Relevance Threshold')} | ||||
| 							</div> | ||||
| 
 | ||||
| 							<div class="self-center p-3"> | ||||
| 								<input | ||||
| 									class=" w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none" | ||||
| 									type="number" | ||||
| 									step="0.01" | ||||
| 									placeholder={$i18n.t('Enter Relevance Threshold')} | ||||
| 									bind:value={querySettings.r} | ||||
| 									autocomplete="off" | ||||
| 									min="0.0" | ||||
| 								/> | ||||
| 							</div> | ||||
| 						</div> | ||||
| 					</div> | ||||
| 
 | ||||
| 					<div> | ||||
| 						<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div> | ||||
| 						<textarea | ||||