forked from open-webui/open-webui
		
	refac
This commit is contained in:
		
							parent
							
								
									485236624f
								
							
						
					
					
						commit
						50f7b20ac2
					
				
					 4 changed files with 91 additions and 58 deletions
				
			
		|  | @ -10,6 +10,7 @@ from fastapi import ( | ||||||
| ) | ) | ||||||
| from fastapi.middleware.cors import CORSMiddleware | from fastapi.middleware.cors import CORSMiddleware | ||||||
| import os, shutil | import os, shutil | ||||||
|  | from typing import List | ||||||
| 
 | 
 | ||||||
| # from chromadb.utils import embedding_functions | # from chromadb.utils import embedding_functions | ||||||
| 
 | 
 | ||||||
|  | @ -96,19 +97,22 @@ async def get_status(): | ||||||
|     return {"status": True} |     return {"status": True} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.get("/query/{collection_name}") | class QueryCollectionForm(BaseModel): | ||||||
|  |     collection_name: str | ||||||
|  |     query: str | ||||||
|  |     k: Optional[int] = 4 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @app.post("/query/collection") | ||||||
| def query_collection( | def query_collection( | ||||||
|     collection_name: str, |     form_data: QueryCollectionForm, | ||||||
|     query: str, |  | ||||||
|     k: Optional[int] = 4, |  | ||||||
|     user=Depends(get_current_user), |     user=Depends(get_current_user), | ||||||
| ): | ): | ||||||
|     try: |     try: | ||||||
|         collection = CHROMA_CLIENT.get_collection( |         collection = CHROMA_CLIENT.get_collection( | ||||||
|             name=collection_name, |             name=form_data.collection_name, | ||||||
|         ) |         ) | ||||||
|         result = collection.query(query_texts=[query], n_results=k) |         result = collection.query(query_texts=[form_data.query], n_results=form_data.k) | ||||||
| 
 |  | ||||||
|         return result |         return result | ||||||
|     except Exception as e: |     except Exception as e: | ||||||
|         print(e) |         print(e) | ||||||
|  | @ -118,6 +122,34 @@ def query_collection( | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | class QueryCollectionsForm(BaseModel): | ||||||
|  |     collection_names: List[str] | ||||||
|  |     query: str | ||||||
|  |     k: Optional[int] = 4 | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @app.post("/query/collections") | ||||||
|  | def query_collections( | ||||||
|  |     form_data: QueryCollectionsForm, | ||||||
|  |     user=Depends(get_current_user), | ||||||
|  | ): | ||||||
|  |     results = [] | ||||||
|  | 
 | ||||||
|  |     for collection_name in form_data.collection_names: | ||||||
|  |         try: | ||||||
|  |             collection = CHROMA_CLIENT.get_collection( | ||||||
|  |                 name=collection_name, | ||||||
|  |             ) | ||||||
|  |             result = collection.query( | ||||||
|  |                 query_texts=[form_data.query], n_results=form_data.k | ||||||
|  |             ) | ||||||
|  |             results.append(result) | ||||||
|  |         except: | ||||||
|  |             pass | ||||||
|  | 
 | ||||||
|  |     return results | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| @app.post("/web") | @app.post("/web") | ||||||
| def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): | ||||||
|     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" |     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" | ||||||
|  |  | ||||||
|  | @ -66,28 +66,25 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string | ||||||
| 
 | 
 | ||||||
| export const queryVectorDB = async ( | export const queryVectorDB = async ( | ||||||
| 	token: string, | 	token: string, | ||||||
| 	collection_name: string, | 	collection_names: string[], | ||||||
| 	query: string, | 	query: string, | ||||||
| 	k: number | 	k: number | ||||||
| ) => { | ) => { | ||||||
| 	let error = null; | 	let error = null; | ||||||
| 	const searchParams = new URLSearchParams(); |  | ||||||
| 
 | 
 | ||||||
| 	searchParams.set('query', query); | 	const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, { | ||||||
| 	if (k) { | 		method: 'POST', | ||||||
| 		searchParams.set('k', k.toString()); | 		headers: { | ||||||
| 	} | 			Accept: 'application/json', | ||||||
| 
 | 			'Content-Type': 'application/json', | ||||||
| 	const res = await fetch( | 			authorization: `Bearer ${token}` | ||||||
| 		`${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`, | 		}, | ||||||
| 		{ | 		body: JSON.stringify({ | ||||||
| 			method: 'GET', | 			collection_names: collection_names, | ||||||
| 			headers: { | 			query: query, | ||||||
| 				Accept: 'application/json', | 			k: k | ||||||
| 				authorization: `Bearer ${token}` | 		}) | ||||||
| 			} | 	}) | ||||||
| 		} |  | ||||||
| 	) |  | ||||||
| 		.then(async (res) => { | 		.then(async (res) => { | ||||||
| 			if (!res.ok) throw await res.json(); | 			if (!res.ok) throw await res.json(); | ||||||
| 			return res.json(); | 			return res.json(); | ||||||
|  |  | ||||||
|  | @ -232,26 +232,28 @@ | ||||||
| 			processing = 'Reading'; | 			processing = 'Reading'; | ||||||
| 			const query = history.messages[parentId].content; | 			const query = history.messages[parentId].content; | ||||||
| 
 | 
 | ||||||
| 			let relevantContexts = await Promise.all( | 			let relevantContexts = await queryVectorDB( | ||||||
| 				docs.map(async (doc) => { | 				localStorage.token, | ||||||
| 					return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( | 				docs.map((d) => d.collection_name), | ||||||
| 						(error) => { | 				query, | ||||||
| 							console.log(error); | 				4 | ||||||
| 							return null; | 			).catch((error) => { | ||||||
| 						} | 				console.log(error); | ||||||
| 					); | 				return null; | ||||||
| 				}) | 			}); | ||||||
| 			); |  | ||||||
| 			relevantContexts = relevantContexts.filter((context) => context); |  | ||||||
| 
 | 
 | ||||||
| 			const contextString = relevantContexts.reduce((a, context, i, arr) => { | 			if (relevantContexts) { | ||||||
| 				return `${a}${context.documents.join(' ')}\n`; | 				relevantContexts = relevantContexts.filter((context) => context); | ||||||
| 			}, ''); |  | ||||||
| 
 | 
 | ||||||
| 			console.log(contextString); | 				const contextString = relevantContexts.reduce((a, context, i, arr) => { | ||||||
|  | 					return `${a}${context.documents.join(' ')}\n`; | ||||||
|  | 				}, ''); | ||||||
| 
 | 
 | ||||||
| 			history.messages[parentId].raContent = RAGTemplate(contextString, query); | 				console.log(contextString); | ||||||
| 			history.messages[parentId].contexts = relevantContexts; | 
 | ||||||
|  | 				history.messages[parentId].raContent = RAGTemplate(contextString, query); | ||||||
|  | 				history.messages[parentId].contexts = relevantContexts; | ||||||
|  | 			} | ||||||
| 			await tick(); | 			await tick(); | ||||||
| 			processing = ''; | 			processing = ''; | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
|  | @ -246,26 +246,28 @@ | ||||||
| 			processing = 'Reading'; | 			processing = 'Reading'; | ||||||
| 			const query = history.messages[parentId].content; | 			const query = history.messages[parentId].content; | ||||||
| 
 | 
 | ||||||
| 			let relevantContexts = await Promise.all( | 			let relevantContexts = await queryVectorDB( | ||||||
| 				docs.map(async (doc) => { | 				localStorage.token, | ||||||
| 					return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( | 				docs.map((d) => d.collection_name), | ||||||
| 						(error) => { | 				query, | ||||||
| 							console.log(error); | 				4 | ||||||
| 							return null; | 			).catch((error) => { | ||||||
| 						} | 				console.log(error); | ||||||
| 					); | 				return null; | ||||||
| 				}) | 			}); | ||||||
| 			); |  | ||||||
| 			relevantContexts = relevantContexts.filter((context) => context); |  | ||||||
| 
 | 
 | ||||||
| 			const contextString = relevantContexts.reduce((a, context, i, arr) => { | 			if (relevantContexts) { | ||||||
| 				return `${a}${context.documents.join(' ')}\n`; | 				relevantContexts = relevantContexts.filter((context) => context); | ||||||
| 			}, ''); |  | ||||||
| 
 | 
 | ||||||
| 			console.log(contextString); | 				const contextString = relevantContexts.reduce((a, context, i, arr) => { | ||||||
|  | 					return `${a}${context.documents.join(' ')}\n`; | ||||||
|  | 				}, ''); | ||||||
| 
 | 
 | ||||||
| 			history.messages[parentId].raContent = RAGTemplate(contextString, query); | 				console.log(contextString); | ||||||
| 			history.messages[parentId].contexts = relevantContexts; | 
 | ||||||
|  | 				history.messages[parentId].raContent = RAGTemplate(contextString, query); | ||||||
|  | 				history.messages[parentId].contexts = relevantContexts; | ||||||
|  | 			} | ||||||
| 			await tick(); | 			await tick(); | ||||||
| 			processing = ''; | 			processing = ''; | ||||||
| 		} | 		} | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek