forked from open-webui/open-webui
		
	feat: toggle hybrid search
This commit is contained in:
		
							parent
							
								
									984dbf13ab
								
							
						
					
					
						commit
						9755cd5baa
					
				
					 4 changed files with 133 additions and 88 deletions
				
			
		|  | @ -70,6 +70,7 @@ from config import ( | ||||||
|     RAG_EMBEDDING_MODEL, |     RAG_EMBEDDING_MODEL, | ||||||
|     RAG_EMBEDDING_MODEL_AUTO_UPDATE, |     RAG_EMBEDDING_MODEL_AUTO_UPDATE, | ||||||
|     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, |     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, | ||||||
|  |     RAG_HYBRID, | ||||||
|     RAG_RERANKING_MODEL, |     RAG_RERANKING_MODEL, | ||||||
|     RAG_RERANKING_MODEL_AUTO_UPDATE, |     RAG_RERANKING_MODEL_AUTO_UPDATE, | ||||||
|     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, |     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, | ||||||
|  | @ -91,6 +92,8 @@ app = FastAPI() | ||||||
| 
 | 
 | ||||||
| app.state.TOP_K = RAG_TOP_K | app.state.TOP_K = RAG_TOP_K | ||||||
| app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD | app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD | ||||||
|  | app.state.HYBRID = RAG_HYBRID | ||||||
|  | 
 | ||||||
| app.state.CHUNK_SIZE = CHUNK_SIZE | app.state.CHUNK_SIZE = CHUNK_SIZE | ||||||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||||
| 
 | 
 | ||||||
|  | @ -321,6 +324,7 @@ async def get_query_settings(user=Depends(get_admin_user)): | ||||||
|         "template": app.state.RAG_TEMPLATE, |         "template": app.state.RAG_TEMPLATE, | ||||||
|         "k": app.state.TOP_K, |         "k": app.state.TOP_K, | ||||||
|         "r": app.state.RELEVANCE_THRESHOLD, |         "r": app.state.RELEVANCE_THRESHOLD, | ||||||
|  |         "hybrid": app.state.HYBRID, | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -328,6 +332,7 @@ class QuerySettingsForm(BaseModel): | ||||||
|     k: Optional[int] = None |     k: Optional[int] = None | ||||||
|     r: Optional[float] = None |     r: Optional[float] = None | ||||||
|     template: Optional[str] = None |     template: Optional[str] = None | ||||||
|  |     hybrid: Optional[bool] = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/query/settings/update") | @app.post("/query/settings/update") | ||||||
|  | @ -337,7 +342,14 @@ async def update_query_settings( | ||||||
|     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE |     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE | ||||||
|     app.state.TOP_K = form_data.k if form_data.k else 4 |     app.state.TOP_K = form_data.k if form_data.k else 4 | ||||||
|     app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 |     app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 | ||||||
|     return {"status": True, "template": app.state.RAG_TEMPLATE} |     app.state.HYBRID = form_data.hybrid if form_data.hybrid else False | ||||||
|  |     return { | ||||||
|  |         "status": True, | ||||||
|  |         "template": app.state.RAG_TEMPLATE, | ||||||
|  |         "k": app.state.TOP_K, | ||||||
|  |         "r": app.state.RELEVANCE_THRESHOLD, | ||||||
|  |         "hybrid": app.state.HYBRID, | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class QueryDocForm(BaseModel): | class QueryDocForm(BaseModel): | ||||||
|  | @ -345,6 +357,7 @@ class QueryDocForm(BaseModel): | ||||||
|     query: str |     query: str | ||||||
|     k: Optional[int] = None |     k: Optional[int] = None | ||||||
|     r: Optional[float] = None |     r: Optional[float] = None | ||||||
|  |     hybrid: Optional[bool] = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/query/doc") | @app.post("/query/doc") | ||||||
|  | @ -368,6 +381,7 @@ def query_doc_handler( | ||||||
|             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, |             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||||
|             embeddings_function=embeddings_function, |             embeddings_function=embeddings_function, | ||||||
|             reranking_function=app.state.sentence_transformer_rf, |             reranking_function=app.state.sentence_transformer_rf, | ||||||
|  |             hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, | ||||||
|         ) |         ) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         log.exception(e) |         log.exception(e) | ||||||
|  | @ -382,6 +396,7 @@ class QueryCollectionsForm(BaseModel): | ||||||
|     query: str |     query: str | ||||||
|     k: Optional[int] = None |     k: Optional[int] = None | ||||||
|     r: Optional[float] = None |     r: Optional[float] = None | ||||||
|  |     hybrid: Optional[bool] = None | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/query/collection") | @app.post("/query/collection") | ||||||
|  | @ -405,6 +420,7 @@ def query_collection_handler( | ||||||
|             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, |             r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, | ||||||
|             embeddings_function=embeddings_function, |             embeddings_function=embeddings_function, | ||||||
|             reranking_function=app.state.sentence_transformer_rf, |             reranking_function=app.state.sentence_transformer_rf, | ||||||
|  |             hybrid=form_data.hybrid if form_data.hybrid else app.state.HYBRID, | ||||||
|         ) |         ) | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         log.exception(e) |         log.exception(e) | ||||||
|  |  | ||||||
|  | @ -32,13 +32,13 @@ def query_embeddings_doc( | ||||||
|     collection_name: str, |     collection_name: str, | ||||||
|     query: str, |     query: str, | ||||||
|     embeddings_function, |     embeddings_function, | ||||||
|  |     reranking_function, | ||||||
|     k: int, |     k: int, | ||||||
|     reranking_function: Optional[CrossEncoder] = None, |  | ||||||
|     r: Optional[float] = None, |     r: Optional[float] = None, | ||||||
|  |     hybrid: Optional[bool] = False, | ||||||
| ): | ): | ||||||
|     try: |     try: | ||||||
| 
 |         if hybrid: | ||||||
|         if reranking_function: |  | ||||||
|             # if you use docker use the model from the environment variable |             # if you use docker use the model from the environment variable | ||||||
|             collection = CHROMA_CLIENT.get_collection(name=collection_name) |             collection = CHROMA_CLIENT.get_collection(name=collection_name) | ||||||
| 
 | 
 | ||||||
|  | @ -142,6 +142,7 @@ def query_embeddings_collection( | ||||||
|     r: float, |     r: float, | ||||||
|     embeddings_function, |     embeddings_function, | ||||||
|     reranking_function, |     reranking_function, | ||||||
|  |     hybrid: bool, | ||||||
| ): | ): | ||||||
| 
 | 
 | ||||||
|     results = [] |     results = [] | ||||||
|  | @ -155,6 +156,7 @@ def query_embeddings_collection( | ||||||
|                 r=r, |                 r=r, | ||||||
|                 embeddings_function=embeddings_function, |                 embeddings_function=embeddings_function, | ||||||
|                 reranking_function=reranking_function, |                 reranking_function=reranking_function, | ||||||
|  |                 hybrid=hybrid, | ||||||
|             ) |             ) | ||||||
|             results.append(result) |             results.append(result) | ||||||
|         except: |         except: | ||||||
|  | @ -211,6 +213,7 @@ def rag_messages( | ||||||
|     template, |     template, | ||||||
|     k, |     k, | ||||||
|     r, |     r, | ||||||
|  |     hybrid, | ||||||
|     embedding_engine, |     embedding_engine, | ||||||
|     embedding_model, |     embedding_model, | ||||||
|     embedding_function, |     embedding_function, | ||||||
|  | @ -283,6 +286,7 @@ def rag_messages( | ||||||
|                     r=r, |                     r=r, | ||||||
|                     embeddings_function=embeddings_function, |                     embeddings_function=embeddings_function, | ||||||
|                     reranking_function=reranking_function, |                     reranking_function=reranking_function, | ||||||
|  |                     hybrid=hybrid, | ||||||
|                 ) |                 ) | ||||||
|             else: |             else: | ||||||
|                 context = query_embeddings_doc( |                 context = query_embeddings_doc( | ||||||
|  | @ -292,6 +296,7 @@ def rag_messages( | ||||||
|                     r=r, |                     r=r, | ||||||
|                     embeddings_function=embeddings_function, |                     embeddings_function=embeddings_function, | ||||||
|                     reranking_function=reranking_function, |                     reranking_function=reranking_function, | ||||||
|  |                     hybrid=hybrid, | ||||||
|                 ) |                 ) | ||||||
|         except Exception as e: |         except Exception as e: | ||||||
|             log.exception(e) |             log.exception(e) | ||||||
|  |  | ||||||
|  | @ -422,6 +422,7 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" | ||||||
| 
 | 
 | ||||||
| RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) | RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5")) | ||||||
| RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) | RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")) | ||||||
|  | RAG_HYBRID = os.environ.get("RAG_HYBRID", "").lower() == "true" | ||||||
| 
 | 
 | ||||||
| RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") | RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -43,7 +43,8 @@ | ||||||
| 	let querySettings = { | 	let querySettings = { | ||||||
| 		template: '', | 		template: '', | ||||||
| 		r: 0.0, | 		r: 0.0, | ||||||
| 		k: 4 | 		k: 4, | ||||||
|  | 		hybrid: false | ||||||
| 	}; | 	}; | ||||||
| 
 | 
 | ||||||
| 	const scanHandler = async () => { | 	const scanHandler = async () => { | ||||||
|  | @ -174,6 +175,12 @@ | ||||||
| 		} | 		} | ||||||
| 	}; | 	}; | ||||||
| 
 | 
 | ||||||
|  | 	const toggleHybridSearch = async () => { | ||||||
|  | 		querySettings.hybrid = !querySettings.hybrid; | ||||||
|  | 
 | ||||||
|  | 		querySettings = await updateQuerySettings(localStorage.token, querySettings); | ||||||
|  | 	}; | ||||||
|  | 
 | ||||||
| 	onMount(async () => { | 	onMount(async () => { | ||||||
| 		const res = await getRAGConfig(localStorage.token); | 		const res = await getRAGConfig(localStorage.token); | ||||||
| 
 | 
 | ||||||
|  | @ -202,6 +209,24 @@ | ||||||
| 		<div> | 		<div> | ||||||
| 			<div class=" mb-2 text-sm font-medium">{$i18n.t('General Settings')}</div> | 			<div class=" mb-2 text-sm font-medium">{$i18n.t('General Settings')}</div> | ||||||
| 
 | 
 | ||||||
|  | 			<div class=" flex w-full justify-between"> | ||||||
|  | 				<div class=" self-center text-xs font-medium">{$i18n.t('Hybrid Search')}</div> | ||||||
|  | 
 | ||||||
|  | 				<button | ||||||
|  | 					class="p-1 px-3 text-xs flex rounded transition" | ||||||
|  | 					on:click={() => { | ||||||
|  | 						toggleHybridSearch(); | ||||||
|  | 					}} | ||||||
|  | 					type="button" | ||||||
|  | 				> | ||||||
|  | 					{#if querySettings.hybrid === true} | ||||||
|  | 						<span class="ml-2 self-center">{$i18n.t('On')}</span> | ||||||
|  | 					{:else} | ||||||
|  | 						<span class="ml-2 self-center">{$i18n.t('Off')}</span> | ||||||
|  | 					{/if} | ||||||
|  | 				</button> | ||||||
|  | 			</div> | ||||||
|  | 
 | ||||||
| 			<div class=" flex w-full justify-between"> | 			<div class=" flex w-full justify-between"> | ||||||
| 				<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Model Engine')}</div> | 				<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Model Engine')}</div> | ||||||
| 				<div class="flex items-center relative"> | 				<div class="flex items-center relative"> | ||||||
|  | @ -386,6 +411,7 @@ | ||||||
| 
 | 
 | ||||||
| 				<hr class=" dark:border-gray-700 my-3" /> | 				<hr class=" dark:border-gray-700 my-3" /> | ||||||
| 
 | 
 | ||||||
|  | 				{#if querySettings.hybrid === true} | ||||||
| 					<div class=" "> | 					<div class=" "> | ||||||
| 						<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Reranking Model')}</div> | 						<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Reranking Model')}</div> | ||||||
| 
 | 
 | ||||||
|  | @ -451,13 +477,8 @@ | ||||||
| 						</div> | 						</div> | ||||||
| 					</div> | 					</div> | ||||||
| 
 | 
 | ||||||
| 				<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500"> |  | ||||||
| 					{$i18n.t( |  | ||||||
| 						'Note: If you choose a reranking model, it will use that to score and rerank instead of the embedding model.' |  | ||||||
| 					)} |  | ||||||
| 				</div> |  | ||||||
| 
 |  | ||||||
| 					<hr class=" dark:border-gray-700 my-3" /> | 					<hr class=" dark:border-gray-700 my-3" /> | ||||||
|  | 				{/if} | ||||||
| 
 | 
 | ||||||
| 				<div class="  flex w-full justify-between"> | 				<div class="  flex w-full justify-between"> | ||||||
| 					<div class=" self-center text-xs font-medium"> | 					<div class=" self-center text-xs font-medium"> | ||||||
|  | @ -583,6 +604,7 @@ | ||||||
| 						</div> | 						</div> | ||||||
| 					</div> | 					</div> | ||||||
| 
 | 
 | ||||||
|  | 					{#if querySettings.hybrid === true} | ||||||
| 						<div class=" flex"> | 						<div class=" flex"> | ||||||
| 							<div class="  flex w-full justify-between"> | 							<div class="  flex w-full justify-between"> | ||||||
| 								<div class="self-center text-xs font-medium flex-1"> | 								<div class="self-center text-xs font-medium flex-1"> | ||||||
|  | @ -602,6 +624,7 @@ | ||||||
| 								</div> | 								</div> | ||||||
| 							</div> | 							</div> | ||||||
| 						</div> | 						</div> | ||||||
|  | 					{/if} | ||||||
| 
 | 
 | ||||||
| 					<div> | 					<div> | ||||||
| 						<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div> | 						<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div> | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Steven Kreitzer
						Steven Kreitzer