forked from open-webui/open-webui
		
	feat: collection rag integration
This commit is contained in:
		
							parent
							
								
									7d2f788a3b
								
							
						
					
					
						commit
						683650ec00
					
				
					 5 changed files with 112 additions and 24 deletions
				
			
		|  | @ -97,15 +97,15 @@ async def get_status(): | ||||||
|     return {"status": True} |     return {"status": True} | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class QueryCollectionForm(BaseModel): | class QueryDocForm(BaseModel): | ||||||
|     collection_name: str |     collection_name: str | ||||||
|     query: str |     query: str | ||||||
|     k: Optional[int] = 4 |     k: Optional[int] = 4 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/query/collection") | @app.post("/query/doc") | ||||||
| def query_collection( | def query_doc( | ||||||
|     form_data: QueryCollectionForm, |     form_data: QueryDocForm, | ||||||
|     user=Depends(get_current_user), |     user=Depends(get_current_user), | ||||||
| ): | ): | ||||||
|     try: |     try: | ||||||
|  | @ -173,8 +173,8 @@ def merge_and_sort_query_results(query_results, k): | ||||||
|     return merged_query_results |     return merged_query_results | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @app.post("/query/collections") | @app.post("/query/collection") | ||||||
| def query_collections( | def query_collection( | ||||||
|     form_data: QueryCollectionsForm, |     form_data: QueryCollectionsForm, | ||||||
|     user=Depends(get_current_user), |     user=Depends(get_current_user), | ||||||
| ): | ): | ||||||
|  |  | ||||||
|  | @ -64,7 +64,7 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string | ||||||
| 	return res; | 	return res; | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| export const queryCollection = async ( | export const queryDoc = async ( | ||||||
| 	token: string, | 	token: string, | ||||||
| 	collection_name: string, | 	collection_name: string, | ||||||
| 	query: string, | 	query: string, | ||||||
|  | @ -72,6 +72,43 @@ export const queryCollection = async ( | ||||||
| ) => { | ) => { | ||||||
| 	let error = null; | 	let error = null; | ||||||
| 
 | 
 | ||||||
|  | 	const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, { | ||||||
|  | 		method: 'POST', | ||||||
|  | 		headers: { | ||||||
|  | 			Accept: 'application/json', | ||||||
|  | 			'Content-Type': 'application/json', | ||||||
|  | 			authorization: `Bearer ${token}` | ||||||
|  | 		}, | ||||||
|  | 		body: JSON.stringify({ | ||||||
|  | 			collection_name: collection_name, | ||||||
|  | 			query: query, | ||||||
|  | 			k: k | ||||||
|  | 		}) | ||||||
|  | 	}) | ||||||
|  | 		.then(async (res) => { | ||||||
|  | 			if (!res.ok) throw await res.json(); | ||||||
|  | 			return res.json(); | ||||||
|  | 		}) | ||||||
|  | 		.catch((err) => { | ||||||
|  | 			error = err.detail; | ||||||
|  | 			return null; | ||||||
|  | 		}); | ||||||
|  | 
 | ||||||
|  | 	if (error) { | ||||||
|  | 		throw error; | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	return res; | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | export const queryCollection = async ( | ||||||
|  | 	token: string, | ||||||
|  | 	collection_names: string, | ||||||
|  | 	query: string, | ||||||
|  | 	k: number | ||||||
|  | ) => { | ||||||
|  | 	let error = null; | ||||||
|  | 
 | ||||||
| 	const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { | 	const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { | ||||||
| 		method: 'POST', | 		method: 'POST', | ||||||
| 		headers: { | 		headers: { | ||||||
|  | @ -80,7 +117,7 @@ export const queryCollection = async ( | ||||||
| 			authorization: `Bearer ${token}` | 			authorization: `Bearer ${token}` | ||||||
| 		}, | 		}, | ||||||
| 		body: JSON.stringify({ | 		body: JSON.stringify({ | ||||||
| 			collection_name: collection_name, | 			collection_names: collection_names, | ||||||
| 			query: query, | 			query: query, | ||||||
| 			k: k | 			k: k | ||||||
| 		}) | 		}) | ||||||
|  |  | ||||||
|  | @ -117,6 +117,35 @@ | ||||||
| 										<div class=" text-gray-500 text-sm">Document</div> | 										<div class=" text-gray-500 text-sm">Document</div> | ||||||
| 									</div> | 									</div> | ||||||
| 								</button> | 								</button> | ||||||
|  | 							{:else if file.type === 'collection'} | ||||||
|  | 								<button | ||||||
|  | 									class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none text-left" | ||||||
|  | 									type="button" | ||||||
|  | 								> | ||||||
|  | 									<div class="p-2.5 bg-red-400 text-white rounded-lg"> | ||||||
|  | 										<svg | ||||||
|  | 											xmlns="http://www.w3.org/2000/svg" | ||||||
|  | 											viewBox="0 0 24 24" | ||||||
|  | 											fill="currentColor" | ||||||
|  | 											class="w-6 h-6" | ||||||
|  | 										> | ||||||
|  | 											<path | ||||||
|  | 												d="M7.5 3.375c0-1.036.84-1.875 1.875-1.875h.375a3.75 3.75 0 0 1 3.75 3.75v1.875C13.5 8.161 14.34 9 15.375 9h1.875A3.75 3.75 0 0 1 21 12.75v3.375C21 17.16 20.16 18 19.125 18h-9.75A1.875 1.875 0 0 1 7.5 16.125V3.375Z" | ||||||
|  | 											/> | ||||||
|  | 											<path | ||||||
|  | 												d="M15 5.25a5.23 5.23 0 0 0-1.279-3.434 9.768 9.768 0 0 1 6.963 6.963A5.23 5.23 0 0 0 17.25 7.5h-1.875A.375.375 0 0 1 15 7.125V5.25ZM4.875 6H6v10.125A3.375 3.375 0 0 0 9.375 19.5H16.5v1.125c0 1.035-.84 1.875-1.875 1.875h-9.75A1.875 1.875 0 0 1 3 20.625V7.875C3 6.839 3.84 6 4.875 6Z" | ||||||
|  | 											/> | ||||||
|  | 										</svg> | ||||||
|  | 									</div> | ||||||
|  | 
 | ||||||
|  | 									<div class="flex flex-col justify-center -space-y-0.5"> | ||||||
|  | 										<div class=" dark:text-gray-100 text-sm font-medium line-clamp-1"> | ||||||
|  | 											#{file.name} | ||||||
|  | 										</div> | ||||||
|  | 
 | ||||||
|  | 										<div class=" text-gray-500 text-sm">Collection</div> | ||||||
|  | 									</div> | ||||||
|  | 								</button> | ||||||
| 							{/if} | 							{/if} | ||||||
| 						</div> | 						</div> | ||||||
| 					{/each} | 					{/each} | ||||||
|  |  | ||||||
|  | @ -28,7 +28,7 @@ | ||||||
| 		getTagsById, | 		getTagsById, | ||||||
| 		updateChatById | 		updateChatById | ||||||
| 	} from '$lib/apis/chats'; | 	} from '$lib/apis/chats'; | ||||||
| 	import { queryCollection } from '$lib/apis/rag'; | 	import { queryCollection, queryDoc } from '$lib/apis/rag'; | ||||||
| 	import { generateOpenAIChatCompletion } from '$lib/apis/openai'; | 	import { generateOpenAIChatCompletion } from '$lib/apis/openai'; | ||||||
| 
 | 
 | ||||||
| 	import MessageInput from '$lib/components/chat/MessageInput.svelte'; | 	import MessageInput from '$lib/components/chat/MessageInput.svelte'; | ||||||
|  | @ -224,7 +224,9 @@ | ||||||
| 
 | 
 | ||||||
| 		const docs = messages | 		const docs = messages | ||||||
| 			.filter((message) => message?.files ?? null) | 			.filter((message) => message?.files ?? null) | ||||||
| 			.map((message) => message.files.filter((item) => item.type === 'doc')) | 			.map((message) => | ||||||
|  | 				message.files.filter((item) => item.type === 'doc' || item.type === 'collection') | ||||||
|  | 			) | ||||||
| 			.flat(1); | 			.flat(1); | ||||||
| 
 | 
 | ||||||
| 		console.log(docs); | 		console.log(docs); | ||||||
|  | @ -234,12 +236,21 @@ | ||||||
| 
 | 
 | ||||||
| 			let relevantContexts = await Promise.all( | 			let relevantContexts = await Promise.all( | ||||||
| 				docs.map(async (doc) => { | 				docs.map(async (doc) => { | ||||||
| 					return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( | 					if (doc.type === 'collection') { | ||||||
|  | 						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( | ||||||
| 							(error) => { | 							(error) => { | ||||||
| 								console.log(error); | 								console.log(error); | ||||||
| 								return null; | 								return null; | ||||||
| 							} | 							} | ||||||
| 						); | 						); | ||||||
|  | 					} else { | ||||||
|  | 						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( | ||||||
|  | 							(error) => { | ||||||
|  | 								console.log(error); | ||||||
|  | 								return null; | ||||||
|  | 							} | ||||||
|  | 						); | ||||||
|  | 					} | ||||||
| 				}) | 				}) | ||||||
| 			); | 			); | ||||||
| 			relevantContexts = relevantContexts.filter((context) => context); | 			relevantContexts = relevantContexts.filter((context) => context); | ||||||
|  |  | ||||||
|  | @ -29,7 +29,7 @@ | ||||||
| 		getTagsById, | 		getTagsById, | ||||||
| 		updateChatById | 		updateChatById | ||||||
| 	} from '$lib/apis/chats'; | 	} from '$lib/apis/chats'; | ||||||
| 	import { queryCollection } from '$lib/apis/rag'; | 	import { queryCollection, queryDoc } from '$lib/apis/rag'; | ||||||
| 	import { generateOpenAIChatCompletion } from '$lib/apis/openai'; | 	import { generateOpenAIChatCompletion } from '$lib/apis/openai'; | ||||||
| 
 | 
 | ||||||
| 	import MessageInput from '$lib/components/chat/MessageInput.svelte'; | 	import MessageInput from '$lib/components/chat/MessageInput.svelte'; | ||||||
|  | @ -238,7 +238,9 @@ | ||||||
| 
 | 
 | ||||||
| 		const docs = messages | 		const docs = messages | ||||||
| 			.filter((message) => message?.files ?? null) | 			.filter((message) => message?.files ?? null) | ||||||
| 			.map((message) => message.files.filter((item) => item.type === 'doc')) | 			.map((message) => | ||||||
|  | 				message.files.filter((item) => item.type === 'doc' || item.type === 'collection') | ||||||
|  | 			) | ||||||
| 			.flat(1); | 			.flat(1); | ||||||
| 
 | 
 | ||||||
| 		console.log(docs); | 		console.log(docs); | ||||||
|  | @ -248,12 +250,21 @@ | ||||||
| 
 | 
 | ||||||
| 			let relevantContexts = await Promise.all( | 			let relevantContexts = await Promise.all( | ||||||
| 				docs.map(async (doc) => { | 				docs.map(async (doc) => { | ||||||
| 					return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( | 					if (doc.type === 'collection') { | ||||||
|  | 						return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( | ||||||
| 							(error) => { | 							(error) => { | ||||||
| 								console.log(error); | 								console.log(error); | ||||||
| 								return null; | 								return null; | ||||||
| 							} | 							} | ||||||
| 						); | 						); | ||||||
|  | 					} else { | ||||||
|  | 						return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( | ||||||
|  | 							(error) => { | ||||||
|  | 								console.log(error); | ||||||
|  | 								return null; | ||||||
|  | 							} | ||||||
|  | 						); | ||||||
|  | 					} | ||||||
| 				}) | 				}) | ||||||
| 			); | 			); | ||||||
| 			relevantContexts = relevantContexts.filter((context) => context); | 			relevantContexts = relevantContexts.filter((context) => context); | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek