forked from open-webui/open-webui
		
	feat: add rag top k value setting
This commit is contained in:
		
							parent
							
								
									9694c6569f
								
							
						
					
					
						commit
						47a05a47b4
					
				
					 5 changed files with 123 additions and 38 deletions
				
			
		|  | @ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE | |||
| app.state.CHUNK_OVERLAP = CHUNK_OVERLAP | ||||
| app.state.RAG_TEMPLATE = RAG_TEMPLATE | ||||
| app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL | ||||
| app.state.TOP_K = 4 | ||||
| 
 | ||||
| app.state.sentence_transformer_ef = ( | ||||
|     embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|         model_name=app.state.RAG_EMBEDDING_MODEL, | ||||
|  | @ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)): | |||
|     } | ||||
| 
 | ||||
| 
 | ||||
| class RAGTemplateForm(BaseModel): | ||||
|     template: str | ||||
| @app.get("/query/settings") | ||||
| async def get_query_settings(user=Depends(get_admin_user)): | ||||
|     return { | ||||
|         "status": True, | ||||
|         "template": app.state.RAG_TEMPLATE, | ||||
|         "k": app.state.TOP_K, | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/template/update") | ||||
| async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)): | ||||
|     # TODO: check template requirements | ||||
|     app.state.RAG_TEMPLATE = ( | ||||
|         form_data.template if form_data.template != "" else RAG_TEMPLATE | ||||
|     ) | ||||
| class QuerySettingsForm(BaseModel): | ||||
|     k: Optional[int] = None | ||||
|     template: Optional[str] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/query/settings/update") | ||||
| async def update_query_settings( | ||||
|     form_data: QuerySettingsForm, user=Depends(get_admin_user) | ||||
| ): | ||||
|     app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE | ||||
|     app.state.TOP_K = form_data.k if form_data.k else 4 | ||||
|     return {"status": True, "template": app.state.RAG_TEMPLATE} | ||||
| 
 | ||||
| 
 | ||||
| class QueryDocForm(BaseModel): | ||||
|     collection_name: str | ||||
|     query: str | ||||
|     k: Optional[int] = 4 | ||||
|     k: Optional[int] = None | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/query/doc") | ||||
|  | @ -240,7 +252,10 @@ def query_doc( | |||
|             name=form_data.collection_name, | ||||
|             embedding_function=app.state.sentence_transformer_ef, | ||||
|         ) | ||||
|         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) | ||||
|         result = collection.query( | ||||
|             query_texts=[form_data.query], | ||||
|             n_results=form_data.k if form_data.k else app.state.TOP_K, | ||||
|         ) | ||||
|         return result | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|  | @ -253,7 +268,7 @@ def query_doc( | |||
| class QueryCollectionsForm(BaseModel): | ||||
|     collection_names: List[str] | ||||
|     query: str | ||||
|     k: Optional[int] = 4 | ||||
|     k: Optional[int] = None | ||||
| 
 | ||||
| 
 | ||||
| def merge_and_sort_query_results(query_results, k): | ||||
|  | @ -317,13 +332,16 @@ def query_collection( | |||
|             ) | ||||
| 
 | ||||
|             result = collection.query( | ||||
|                 query_texts=[form_data.query], n_results=form_data.k | ||||
|                 query_texts=[form_data.query], | ||||
|                 n_results=form_data.k if form_data.k else app.state.TOP_K, | ||||
|             ) | ||||
|             results.append(result) | ||||
|         except: | ||||
|             pass | ||||
| 
 | ||||
|     return merge_and_sort_query_results(results, form_data.k) | ||||
|     return merge_and_sort_query_results( | ||||
|         results, form_data.k if form_data.k else app.state.TOP_K | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/web") | ||||
|  | @ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): | |||
|         "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", | ||||
|     ] or file_ext in ["xls", "xlsx"]: | ||||
|         loader = UnstructuredExcelLoader(file_path) | ||||
|     elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0): | ||||
|     elif file_ext in known_source_ext or ( | ||||
|         file_content_type and file_content_type.find("text/") >= 0 | ||||
|     ): | ||||
|         loader = TextLoader(file_path) | ||||
|     else: | ||||
|         loader = TextLoader(file_path) | ||||
|  |  | |||
|  | @ -85,17 +85,49 @@ export const getRAGTemplate = async (token: string) => { | |||
| 	return res?.template ?? ''; | ||||
| }; | ||||
| 
 | ||||
| export const updateRAGTemplate = async (token: string, template: string) => { | ||||
| export const getQuerySettings = async (token: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${RAG_API_BASE_URL}/template/update`, { | ||||
| 	const res = await fetch(`${RAG_API_BASE_URL}/query/settings`, { | ||||
| 		method: 'GET', | ||||
| 		headers: { | ||||
| 			'Content-Type': 'application/json', | ||||
| 			Authorization: `Bearer ${token}` | ||||
| 		} | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
| 		}) | ||||
| 		.catch((err) => { | ||||
| 			console.log(err); | ||||
| 			error = err.detail; | ||||
| 			return null; | ||||
| 		}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| type QuerySettings = { | ||||
| 	k: number | null; | ||||
| 	template: string | null; | ||||
| }; | ||||
| 
 | ||||
| export const updateQuerySettings = async (token: string, settings: QuerySettings) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${RAG_API_BASE_URL}/query/settings/update`, { | ||||
| 		method: 'POST', | ||||
| 		headers: { | ||||
| 			'Content-Type': 'application/json', | ||||
| 			Authorization: `Bearer ${token}` | ||||
| 		}, | ||||
| 		body: JSON.stringify({ | ||||
| 			template: template | ||||
| 			...settings | ||||
| 		}) | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
|  | @ -183,7 +215,7 @@ export const queryDoc = async ( | |||
| 	token: string, | ||||
| 	collection_name: string, | ||||
| 	query: string, | ||||
| 	k: number | ||||
| 	k: number | null = null | ||||
| ) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,10 +2,10 @@ | |||
| 	import { getDocs } from '$lib/apis/documents'; | ||||
| 	import { | ||||
| 		getChunkParams, | ||||
| 		getRAGTemplate, | ||||
| 		getQuerySettings, | ||||
| 		scanDocs, | ||||
| 		updateChunkParams, | ||||
| 		updateRAGTemplate | ||||
| 		updateQuerySettings | ||||
| 	} from '$lib/apis/rag'; | ||||
| 	import { documents } from '$lib/stores'; | ||||
| 	import { onMount } from 'svelte'; | ||||
|  | @ -18,7 +18,10 @@ | |||
| 	let chunkSize = 0; | ||||
| 	let chunkOverlap = 0; | ||||
| 
 | ||||
| 	let template = ''; | ||||
| 	let querySettings = { | ||||
| 		template: '', | ||||
| 		k: 4 | ||||
| 	}; | ||||
| 
 | ||||
| 	const scanHandler = async () => { | ||||
| 		loading = true; | ||||
|  | @ -33,7 +36,7 @@ | |||
| 
 | ||||
| 	const submitHandler = async () => { | ||||
| 		const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap); | ||||
| 		await updateRAGTemplate(localStorage.token, template); | ||||
| 		querySettings = await updateQuerySettings(localStorage.token, querySettings); | ||||
| 	}; | ||||
| 
 | ||||
| 	onMount(async () => { | ||||
|  | @ -44,7 +47,7 @@ | |||
| 			chunkOverlap = res.chunk_overlap; | ||||
| 		} | ||||
| 
 | ||||
| 		template = await getRAGTemplate(localStorage.token); | ||||
| 		querySettings = await getQuerySettings(localStorage.token); | ||||
| 	}); | ||||
| </script> | ||||
| 
 | ||||
|  | @ -156,10 +159,44 @@ | |||
| 				</div> | ||||
| 			</div> | ||||
| 
 | ||||
| 			<div class=" text-sm font-medium">Query Params</div> | ||||
| 
 | ||||
| 			<div class=" flex"> | ||||
| 				<div class="  flex w-full justify-between"> | ||||
| 					<div class="self-center text-xs font-medium flex-1">Top K</div> | ||||
| 
 | ||||
| 					<div class="self-center p-3"> | ||||
| 						<input | ||||
| 							class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600" | ||||
| 							type="number" | ||||
| 							placeholder="Enter Top K" | ||||
| 							bind:value={querySettings.k} | ||||
| 							autocomplete="off" | ||||
| 							min="0" | ||||
| 						/> | ||||
| 					</div> | ||||
| 				</div> | ||||
| 
 | ||||
| 				<!-- <div class="flex w-full"> | ||||
| 					<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div> | ||||
| 
 | ||||
| 					<div class="self-center p-3"> | ||||
| 						<input | ||||
| 							class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600" | ||||
| 							type="number" | ||||
| 							placeholder="Enter Chunk Overlap" | ||||
| 							bind:value={chunkOverlap} | ||||
| 							autocomplete="off" | ||||
| 							min="0" | ||||
| 						/> | ||||
| 					</div> | ||||
| 				</div> --> | ||||
| 			</div> | ||||
| 
 | ||||
| 			<div> | ||||
| 				<div class=" mb-2.5 text-sm font-medium">RAG Template</div> | ||||
| 				<textarea | ||||
| 					bind:value={template} | ||||
| 					bind:value={querySettings.template} | ||||
| 					class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none" | ||||
| 					rows="4" | ||||
| 				/> | ||||
|  |  | |||
|  | @ -248,19 +248,17 @@ | |||
| 			let relevantContexts = await Promise.all( | ||||
| 				docs.map(async (doc) => { | ||||
| 					if (doc.type === 'collection') { | ||||
| 						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( | ||||
| 						return await queryCollection(localStorage.token, doc.collection_names, query).catch( | ||||
| 							(error) => { | ||||
| 								console.log(error); | ||||
| 								return null; | ||||
| 							} | ||||
| 						); | ||||
| 					} else { | ||||
| 						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( | ||||
| 							(error) => { | ||||
| 								console.log(error); | ||||
| 								return null; | ||||
| 							} | ||||
| 						); | ||||
| 						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => { | ||||
| 							console.log(error); | ||||
| 							return null; | ||||
| 						}); | ||||
| 					} | ||||
| 				}) | ||||
| 			); | ||||
|  |  | |||
|  | @ -261,19 +261,17 @@ | |||
| 			let relevantContexts = await Promise.all( | ||||
| 				docs.map(async (doc) => { | ||||
| 					if (doc.type === 'collection') { | ||||
| 						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( | ||||
| 						return await queryCollection(localStorage.token, doc.collection_names, query).catch( | ||||
| 							(error) => { | ||||
| 								console.log(error); | ||||
| 								return null; | ||||
| 							} | ||||
| 						); | ||||
| 					} else { | ||||
| 						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( | ||||
| 							(error) => { | ||||
| 								console.log(error); | ||||
| 								return null; | ||||
| 							} | ||||
| 						); | ||||
| 						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => { | ||||
| 							console.log(error); | ||||
| 							return null; | ||||
| 						}); | ||||
| 					} | ||||
| 				}) | ||||
| 			); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek