forked from open-webui/open-webui
		
	refac
This commit is contained in:
		
							parent
							
								
									485236624f
								
							
						
					
					
						commit
						50f7b20ac2
					
				
					 4 changed files with 91 additions and 58 deletions
				
			
		|  | @ -10,6 +10,7 @@ from fastapi import ( | |||
| ) | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| import os, shutil | ||||
| from typing import List | ||||
| 
 | ||||
| # from chromadb.utils import embedding_functions | ||||
| 
 | ||||
|  | @ -96,19 +97,22 @@ async def get_status(): | |||
|     return {"status": True} | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/query/{collection_name}") | ||||
| class QueryCollectionForm(BaseModel): | ||||
|     collection_name: str | ||||
|     query: str | ||||
|     k: Optional[int] = 4 | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/query/collection") | ||||
| def query_collection( | ||||
|     collection_name: str, | ||||
|     query: str, | ||||
|     k: Optional[int] = 4, | ||||
|     form_data: QueryCollectionForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     try: | ||||
|         collection = CHROMA_CLIENT.get_collection( | ||||
|             name=collection_name, | ||||
|             name=form_data.collection_name, | ||||
|         ) | ||||
|         result = collection.query(query_texts=[query], n_results=k) | ||||
| 
 | ||||
|         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) | ||||
|         return result | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|  | @ -118,6 +122,34 @@ def query_collection( | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class QueryCollectionsForm(BaseModel): | ||||
|     collection_names: List[str] | ||||
|     query: str | ||||
|     k: Optional[int] = 4 | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/query/collections") | ||||
| def query_collections( | ||||
|     form_data: QueryCollectionsForm, | ||||
|     user=Depends(get_current_user), | ||||
| ): | ||||
|     results = [] | ||||
| 
 | ||||
|     for collection_name in form_data.collection_names: | ||||
|         try: | ||||
|             collection = CHROMA_CLIENT.get_collection( | ||||
|                 name=collection_name, | ||||
|             ) | ||||
|             result = collection.query( | ||||
|                 query_texts=[form_data.query], n_results=form_data.k | ||||
|             ) | ||||
|             results.append(result) | ||||
|         except: | ||||
|             pass | ||||
| 
 | ||||
|     return results | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/web") | ||||
| def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | ||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||
|  |  | |||
|  | @ -66,28 +66,25 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string | |||
| 
 | ||||
| export const queryVectorDB = async ( | ||||
| 	token: string, | ||||
| 	collection_name: string, | ||||
| 	collection_names: string[], | ||||
| 	query: string, | ||||
| 	k: number | ||||
| ) => { | ||||
| 	let error = null; | ||||
| 	const searchParams = new URLSearchParams(); | ||||
| 
 | ||||
| 	searchParams.set('query', query); | ||||
| 	if (k) { | ||||
| 		searchParams.set('k', k.toString()); | ||||
| 	} | ||||
| 
 | ||||
| 	const res = await fetch( | ||||
| 		`${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`, | ||||
| 		{ | ||||
| 			method: 'GET', | ||||
| 	const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, { | ||||
| 		method: 'POST', | ||||
| 		headers: { | ||||
| 			Accept: 'application/json', | ||||
| 			'Content-Type': 'application/json', | ||||
| 			authorization: `Bearer ${token}` | ||||
| 			} | ||||
| 		} | ||||
| 	) | ||||
| 		}, | ||||
| 		body: JSON.stringify({ | ||||
| 			collection_names: collection_names, | ||||
| 			query: query, | ||||
| 			k: k | ||||
| 		}) | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
|  |  | |||
|  | @ -232,16 +232,17 @@ | |||
| 			processing = 'Reading'; | ||||
| 			const query = history.messages[parentId].content; | ||||
| 
 | ||||
| 			let relevantContexts = await Promise.all( | ||||
| 				docs.map(async (doc) => { | ||||
| 					return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( | ||||
| 						(error) => { | ||||
| 			let relevantContexts = await queryVectorDB( | ||||
| 				localStorage.token, | ||||
| 				docs.map((d) => d.collection_name), | ||||
| 				query, | ||||
| 				4 | ||||
| 			).catch((error) => { | ||||
| 				console.log(error); | ||||
| 				return null; | ||||
| 						} | ||||
| 					); | ||||
| 				}) | ||||
| 			); | ||||
| 			}); | ||||
| 
 | ||||
| 			if (relevantContexts) { | ||||
| 				relevantContexts = relevantContexts.filter((context) => context); | ||||
| 
 | ||||
| 				const contextString = relevantContexts.reduce((a, context, i, arr) => { | ||||
|  | @ -252,6 +253,7 @@ | |||
| 
 | ||||
| 				history.messages[parentId].raContent = RAGTemplate(contextString, query); | ||||
| 				history.messages[parentId].contexts = relevantContexts; | ||||
| 			} | ||||
| 			await tick(); | ||||
| 			processing = ''; | ||||
| 		} | ||||
|  |  | |||
|  | @ -246,16 +246,17 @@ | |||
| 			processing = 'Reading'; | ||||
| 			const query = history.messages[parentId].content; | ||||
| 
 | ||||
| 			let relevantContexts = await Promise.all( | ||||
| 				docs.map(async (doc) => { | ||||
| 					return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( | ||||
| 						(error) => { | ||||
| 			let relevantContexts = await queryVectorDB( | ||||
| 				localStorage.token, | ||||
| 				docs.map((d) => d.collection_name), | ||||
| 				query, | ||||
| 				4 | ||||
| 			).catch((error) => { | ||||
| 				console.log(error); | ||||
| 				return null; | ||||
| 						} | ||||
| 					); | ||||
| 				}) | ||||
| 			); | ||||
| 			}); | ||||
| 
 | ||||
| 			if (relevantContexts) { | ||||
| 				relevantContexts = relevantContexts.filter((context) => context); | ||||
| 
 | ||||
| 				const contextString = relevantContexts.reduce((a, context, i, arr) => { | ||||
|  | @ -266,6 +267,7 @@ | |||
| 
 | ||||
| 				history.messages[parentId].raContent = RAGTemplate(contextString, query); | ||||
| 				history.messages[parentId].contexts = relevantContexts; | ||||
| 			} | ||||
| 			await tick(); | ||||
| 			processing = ''; | ||||
| 		} | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek