forked from open-webui/open-webui
		
	refactor: move RAG to backend
This commit is contained in:
		
							parent
							
								
									6ba62cf25d
								
							
						
					
					
						commit
						c49491e516
					
				
					 4 changed files with 113 additions and 50 deletions
				
			
		|  | @ -1,3 +1,4 @@ | |||
| import re | ||||
| from typing import List | ||||
| 
 | ||||
| from config import CHROMA_CLIENT | ||||
|  | @ -87,3 +88,10 @@ def query_collection( | |||
|             pass | ||||
| 
 | ||||
|     return merge_and_sort_query_results(results, k) | ||||
| 
 | ||||
| 
 | ||||
def rag_template(template: str, context: str, query: str) -> str:
    """Render the RAG prompt template.

    Substitutes the literal placeholders ``[context]`` and ``[query]`` in
    *template* with the retrieved document context and the user's query.

    Uses ``str.replace`` rather than ``re.sub``: the context/query come from
    user documents, and ``re.sub`` treats its replacement argument as a regex
    replacement string, so a backslash escape such as ``\\d`` or ``\\g<`` in a
    document would raise ``re.error`` or corrupt the rendered prompt.
    """
    # Replace context first, then query, preserving the original ordering
    # (a "[query]" token inside the context is therefore also filled in).
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    return template
|  |  | |||
|  | @ -12,6 +12,7 @@ from fastapi import HTTPException | |||
| from fastapi.middleware.wsgi import WSGIMiddleware | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from starlette.exceptions import HTTPException as StarletteHTTPException | ||||
| from starlette.middleware.base import BaseHTTPMiddleware | ||||
| 
 | ||||
| 
 | ||||
| from apps.ollama.main import app as ollama_app | ||||
|  | @ -23,6 +24,8 @@ from apps.rag.main import app as rag_app | |||
| from apps.web.main import app as webui_app | ||||
| 
 | ||||
| 
 | ||||
| from apps.rag.utils import query_doc, query_collection, rag_template | ||||
| 
 | ||||
| from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR | ||||
| from constants import ERROR_MESSAGES | ||||
| 
 | ||||
|  | @ -56,6 +59,89 @@ async def on_startup(): | |||
|     await litellm_app_startup() | ||||
| 
 | ||||
| 
 | ||||
class RAGMiddleware(BaseHTTPMiddleware):
    """Inject retrieved document context into chat-completion requests.

    For POST requests whose JSON body contains a ``"docs"`` list, this
    middleware queries the vector store for each doc, renders the RAG
    prompt template with the combined context, and replaces the content
    of the most recent user message with the rendered prompt. The
    ``"docs"`` key is stripped before the request is forwarded, so
    downstream model backends never see it.
    """

    async def dispatch(self, request: Request, call_next):
        if request.method == "POST":
            # Read and parse the original JSON body (empty body -> {}).
            body = await request.body()
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}

            if "docs" in data:
                docs = data.pop("docs")
                messages = data.get("messages", [])

                # The query is the content of the most recent user message.
                last_user_message_idx = None
                for i in range(len(messages) - 1, -1, -1):
                    if messages[i]["role"] == "user":
                        last_user_message_idx = i
                        break

                # Only rewrite when a user message exists to anchor on;
                # indexing with None would raise a TypeError otherwise.
                if last_user_message_idx is not None:
                    query = messages[last_user_message_idx]["content"]

                    relevant_contexts = []
                    for doc in docs:
                        if doc["type"] == "collection":
                            context = query_collection(
                                collection_names=doc["collection_names"],
                                query=query,
                                k=rag_app.state.TOP_K,
                                embedding_function=rag_app.state.sentence_transformer_ef,
                            )
                        else:
                            context = query_doc(
                                collection_name=doc["collection_name"],
                                query=query,
                                k=rag_app.state.TOP_K,
                                embedding_function=rag_app.state.sentence_transformer_ef,
                            )
                        relevant_contexts.append(context)

                    # Concatenate the first document list of every
                    # non-empty query result into one context string.
                    context_string = ""
                    for context in relevant_contexts:
                        if context:
                            context_string += " ".join(context["documents"][0]) + "\n"

                    content = rag_template(
                        template=rag_app.state.RAG_TEMPLATE,
                        context=context_string,
                        query=query,
                    )

                    messages[last_user_message_idx] = {
                        **messages[last_user_message_idx],
                        "content": content,
                    }

            # Re-serialize the (possibly modified) payload and rebuild the
            # request so downstream handlers read the new body. Debug prints
            # of the full body were removed: they leaked chat contents to logs.
            modified_body_bytes = json.dumps(data).encode("utf-8")
            scope = request.scope
            scope["body"] = modified_body_bytes
            request = Request(
                scope, receive=lambda: self._receive(modified_body_bytes)
            )

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        # Minimal ASGI "receive" callable that replays the modified body
        # exactly once (more_body=False signals end of the request stream).
        return {"type": "http.request", "body": body, "more_body": False}
| 
 | ||||
| 
 | ||||
# Register the RAG middleware so document context is injected into chat
# requests before they reach any of the mounted model backends.
app.add_middleware(RAGMiddleware)
| 
 | ||||
| 
 | ||||
| @app.middleware("http") | ||||
| async def check_url(request: Request, call_next): | ||||
|     start_time = int(time.time()) | ||||
|  |  | |||
|  | @ -252,7 +252,7 @@ export const queryCollection = async ( | |||
| 	token: string, | ||||
| 	collection_names: string, | ||||
| 	query: string, | ||||
| 	k: number | ||||
| 	k: number | null = null | ||||
| ) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
|  |  | |||
|  | @ -232,53 +232,6 @@ | |||
| 	const sendPrompt = async (prompt, parentId) => { | ||||
| 		const _chatId = JSON.parse(JSON.stringify($chatId)); | ||||
| 
 | ||||
| 		const docs = messages | ||||
| 			.filter((message) => message?.files ?? null) | ||||
| 			.map((message) => | ||||
| 				message.files.filter((item) => item.type === 'doc' || item.type === 'collection') | ||||
| 			) | ||||
| 			.flat(1); | ||||
| 
 | ||||
| 		console.log(docs); | ||||
| 		if (docs.length > 0) { | ||||
| 			processing = 'Reading'; | ||||
| 			const query = history.messages[parentId].content; | ||||
| 
 | ||||
| 			let relevantContexts = await Promise.all( | ||||
| 				docs.map(async (doc) => { | ||||
| 					if (doc.type === 'collection') { | ||||
| 						return await queryCollection(localStorage.token, doc.collection_names, query).catch( | ||||
| 							(error) => { | ||||
| 								console.log(error); | ||||
| 								return null; | ||||
| 							} | ||||
| 						); | ||||
| 					} else { | ||||
| 						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => { | ||||
| 							console.log(error); | ||||
| 							return null; | ||||
| 						}); | ||||
| 					} | ||||
| 				}) | ||||
| 			); | ||||
| 			relevantContexts = relevantContexts.filter((context) => context); | ||||
| 
 | ||||
| 			const contextString = relevantContexts.reduce((a, context, i, arr) => { | ||||
| 				return `${a}${context.documents.join(' ')}\n`; | ||||
| 			}, ''); | ||||
| 
 | ||||
| 			console.log(contextString); | ||||
| 
 | ||||
| 			history.messages[parentId].raContent = await RAGTemplate( | ||||
| 				localStorage.token, | ||||
| 				contextString, | ||||
| 				query | ||||
| 			); | ||||
| 			history.messages[parentId].contexts = relevantContexts; | ||||
| 			await tick(); | ||||
| 			processing = ''; | ||||
| 		} | ||||
| 
 | ||||
| 		await Promise.all( | ||||
| 			selectedModels.map(async (modelId) => { | ||||
| 				const model = $models.filter((m) => m.id === modelId).at(0); | ||||
|  | @ -368,6 +321,13 @@ | |||
| 			} | ||||
| 		}); | ||||
| 
 | ||||
| 		const docs = messages | ||||
| 			.filter((message) => message?.files ?? null) | ||||
| 			.map((message) => | ||||
| 				message.files.filter((item) => item.type === 'doc' || item.type === 'collection') | ||||
| 			) | ||||
| 			.flat(1); | ||||
| 
 | ||||
| 		const [res, controller] = await generateChatCompletion(localStorage.token, { | ||||
| 			model: model, | ||||
| 			messages: messagesBody, | ||||
|  | @ -375,7 +335,8 @@ | |||
| 				...($settings.options ?? {}) | ||||
| 			}, | ||||
| 			format: $settings.requestFormat ?? undefined, | ||||
| 			keep_alive: $settings.keepAlive ?? undefined | ||||
| 			keep_alive: $settings.keepAlive ?? undefined, | ||||
| 			docs: docs | ||||
| 		}); | ||||
| 
 | ||||
| 		if (res && res.ok) { | ||||
|  | @ -535,6 +496,13 @@ | |||
| 		const responseMessage = history.messages[responseMessageId]; | ||||
| 		scrollToBottom(); | ||||
| 
 | ||||
| 		const docs = messages | ||||
| 			.filter((message) => message?.files ?? null) | ||||
| 			.map((message) => | ||||
| 				message.files.filter((item) => item.type === 'doc' || item.type === 'collection') | ||||
| 			) | ||||
| 			.flat(1); | ||||
| 
 | ||||
| 		const res = await generateOpenAIChatCompletion( | ||||
| 			localStorage.token, | ||||
| 			{ | ||||
|  | @ -583,7 +551,8 @@ | |||
| 				top_p: $settings?.options?.top_p ?? undefined, | ||||
| 				num_ctx: $settings?.options?.num_ctx ?? undefined, | ||||
| 				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined, | ||||
| 				max_tokens: $settings?.options?.num_predict ?? undefined | ||||
| 				max_tokens: $settings?.options?.num_predict ?? undefined, | ||||
| 				docs: docs | ||||
| 			}, | ||||
| 			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}` | ||||
| 		); | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek