forked from open-webui/open-webui
		
	fix: rag
This commit is contained in:
		
							parent
							
								
									88d324b52d
								
							
						
					
					
						commit
						8df6b137cb
					
				
					 2 changed files with 120 additions and 98 deletions
				
			
		|  | @ -95,3 +95,89 @@ def rag_template(template: str, context: str, query: str): | |||
|     template = re.sub(r"\[query\]", query, template) | ||||
| 
 | ||||
|     return template | ||||
| 
 | ||||
| 
 | ||||
| def rag_messages(docs, messages, template, k, embedding_function): | ||||
|     print(docs) | ||||
| 
 | ||||
|     last_user_message_idx = None | ||||
|     for i in range(len(messages) - 1, -1, -1): | ||||
|         if messages[i]["role"] == "user": | ||||
|             last_user_message_idx = i | ||||
|             break | ||||
| 
 | ||||
|     user_message = messages[last_user_message_idx] | ||||
| 
 | ||||
|     if isinstance(user_message["content"], list): | ||||
|         # Handle list content input | ||||
|         content_type = "list" | ||||
|         query = "" | ||||
|         for content_item in user_message["content"]: | ||||
|             if content_item["type"] == "text": | ||||
|                 query = content_item["text"] | ||||
|                 break | ||||
|     elif isinstance(user_message["content"], str): | ||||
|         # Handle text content input | ||||
|         content_type = "text" | ||||
|         query = user_message["content"] | ||||
|     else: | ||||
|         # Fallback in case the input does not match expected types | ||||
|         content_type = None | ||||
|         query = "" | ||||
| 
 | ||||
|     relevant_contexts = [] | ||||
| 
 | ||||
|     for doc in docs: | ||||
|         context = None | ||||
| 
 | ||||
|         try: | ||||
|             if doc["type"] == "collection": | ||||
|                 context = query_collection( | ||||
|                     collection_names=doc["collection_names"], | ||||
|                     query=query, | ||||
|                     k=k, | ||||
|                     embedding_function=embedding_function, | ||||
|                 ) | ||||
|             else: | ||||
|                 context = query_doc( | ||||
|                     collection_name=doc["collection_name"], | ||||
|                     query=query, | ||||
|                     k=k, | ||||
|                     embedding_function=embedding_function, | ||||
|                 ) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             context = None | ||||
| 
 | ||||
|         relevant_contexts.append(context) | ||||
| 
 | ||||
|     context_string = "" | ||||
|     for context in relevant_contexts: | ||||
|         if context: | ||||
|             context_string += " ".join(context["documents"][0]) + "\n" | ||||
| 
 | ||||
|     ra_content = rag_template( | ||||
|         template=template, | ||||
|         context=context_string, | ||||
|         query=query, | ||||
|     ) | ||||
| 
 | ||||
|     if content_type == "list": | ||||
|         new_content = [] | ||||
|         for content_item in user_message["content"]: | ||||
|             if content_item["type"] == "text": | ||||
|                 # Update the text item's content with ra_content | ||||
|                 new_content.append({"type": "text", "text": ra_content}) | ||||
|             else: | ||||
|                 # Keep other types of content as they are | ||||
|                 new_content.append(content_item) | ||||
|         new_user_message = {**user_message, "content": new_content} | ||||
|     else: | ||||
|         new_user_message = { | ||||
|             **user_message, | ||||
|             "content": ra_content, | ||||
|         } | ||||
| 
 | ||||
|     messages[last_user_message_idx] = new_user_message | ||||
| 
 | ||||
|     return messages | ||||
|  |  | |||
							
								
								
									
										132
									
								
								backend/main.py
									
										
									
									
									
								
							
							
						
						
									
										132
									
								
								backend/main.py
									
										
									
									
									
								
							|  | @ -28,7 +28,7 @@ from typing import List | |||
| 
 | ||||
| 
 | ||||
| from utils.utils import get_admin_user | ||||
| from apps.rag.utils import query_doc, query_collection, rag_template | ||||
| from apps.rag.utils import rag_messages | ||||
| 
 | ||||
| from config import ( | ||||
|     WEBUI_NAME, | ||||
|  | @ -60,19 +60,6 @@ app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST | |||
| 
 | ||||
| origins = ["*"] | ||||
| 
 | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|     allow_origins=origins, | ||||
|     allow_credentials=True, | ||||
|     allow_methods=["*"], | ||||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| @app.on_event("startup") | ||||
| async def on_startup(): | ||||
|     await litellm_app_startup() | ||||
| 
 | ||||
| 
 | ||||
| class RAGMiddleware(BaseHTTPMiddleware): | ||||
|     async def dispatch(self, request: Request, call_next): | ||||
|  | @ -91,98 +78,33 @@ class RAGMiddleware(BaseHTTPMiddleware): | |||
|             # Example: Add a new key-value pair or modify existing ones | ||||
|             # data["modified"] = True  # Example modification | ||||
|             if "docs" in data: | ||||
|                 docs = data["docs"] | ||||
|                 print(docs) | ||||
| 
 | ||||
|                 last_user_message_idx = None | ||||
|                 for i in range(len(data["messages"]) - 1, -1, -1): | ||||
|                     if data["messages"][i]["role"] == "user": | ||||
|                         last_user_message_idx = i | ||||
|                         break | ||||
| 
 | ||||
|                 user_message = data["messages"][last_user_message_idx] | ||||
| 
 | ||||
|                 if isinstance(user_message["content"], list): | ||||
|                     # Handle list content input | ||||
|                     content_type = "list" | ||||
|                     query = "" | ||||
|                     for content_item in user_message["content"]: | ||||
|                         if content_item["type"] == "text": | ||||
|                             query = content_item["text"] | ||||
|                             break | ||||
|                 elif isinstance(user_message["content"], str): | ||||
|                     # Handle text content input | ||||
|                     content_type = "text" | ||||
|                     query = user_message["content"] | ||||
|                 else: | ||||
|                     # Fallback in case the input does not match expected types | ||||
|                     content_type = None | ||||
|                     query = "" | ||||
| 
 | ||||
|                 relevant_contexts = [] | ||||
| 
 | ||||
|                 for doc in docs: | ||||
|                     context = None | ||||
| 
 | ||||
|                     try: | ||||
|                         if doc["type"] == "collection": | ||||
|                             context = query_collection( | ||||
|                                 collection_names=doc["collection_names"], | ||||
|                                 query=query, | ||||
|                                 k=rag_app.state.TOP_K, | ||||
|                                 embedding_function=rag_app.state.sentence_transformer_ef, | ||||
|                             ) | ||||
|                         else: | ||||
|                             context = query_doc( | ||||
|                                 collection_name=doc["collection_name"], | ||||
|                                 query=query, | ||||
|                                 k=rag_app.state.TOP_K, | ||||
|                                 embedding_function=rag_app.state.sentence_transformer_ef, | ||||
|                             ) | ||||
|                     except Exception as e: | ||||
|                         print(e) | ||||
|                         context = None | ||||
| 
 | ||||
|                     relevant_contexts.append(context) | ||||
| 
 | ||||
|                 context_string = "" | ||||
|                 for context in relevant_contexts: | ||||
|                     if context: | ||||
|                         context_string += " ".join(context["documents"][0]) + "\n" | ||||
| 
 | ||||
|                 ra_content = rag_template( | ||||
|                     template=rag_app.state.RAG_TEMPLATE, | ||||
|                     context=context_string, | ||||
|                     query=query, | ||||
|                 data = {**data} | ||||
|                 data["messages"] = rag_messages( | ||||
|                     data["docs"], | ||||
|                     data["messages"], | ||||
|                     rag_app.state.RAG_TEMPLATE, | ||||
|                     rag_app.state.TOP_K, | ||||
|                     rag_app.state.sentence_transformer_ef, | ||||
|                 ) | ||||
| 
 | ||||
|                 if content_type == "list": | ||||
|                     new_content = [] | ||||
|                     for content_item in user_message["content"]: | ||||
|                         if content_item["type"] == "text": | ||||
|                             # Update the text item's content with ra_content | ||||
|                             new_content.append({"type": "text", "text": ra_content}) | ||||
|                         else: | ||||
|                             # Keep other types of content as they are | ||||
|                             new_content.append(content_item) | ||||
|                     new_user_message = {**user_message, "content": new_content} | ||||
|                 else: | ||||
|                     new_user_message = { | ||||
|                         **user_message, | ||||
|                         "content": ra_content, | ||||
|                     } | ||||
| 
 | ||||
|                 data["messages"][last_user_message_idx] = new_user_message | ||||
|                 del data["docs"] | ||||
| 
 | ||||
|                 print(data["messages"]) | ||||
| 
 | ||||
|             modified_body_bytes = json.dumps(data).encode("utf-8") | ||||
| 
 | ||||
|             # Create a new request with the modified body | ||||
|             scope = request.scope | ||||
|             scope["body"] = modified_body_bytes | ||||
|             request = Request(scope, receive=lambda: self._receive(modified_body_bytes)) | ||||
|             # Replace the request body with the modified one | ||||
|             request._body = modified_body_bytes | ||||
| 
 | ||||
|             # Set custom header to ensure content-length matches new body length | ||||
|             request.headers.__dict__["_list"] = [ | ||||
|                 (b"content-length", str(len(modified_body_bytes)).encode("utf-8")), | ||||
|                 *[ | ||||
|                     (k, v) | ||||
|                     for k, v in request.headers.raw | ||||
|                     if k.lower() != b"content-length" | ||||
|                 ], | ||||
|             ] | ||||
| 
 | ||||
|         response = await call_next(request) | ||||
|         return response | ||||
|  | @ -194,6 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware): | |||
| app.add_middleware(RAGMiddleware) | ||||
| 
 | ||||
| 
 | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|     allow_origins=origins, | ||||
|     allow_credentials=True, | ||||
|     allow_methods=["*"], | ||||
|     allow_headers=["*"], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| @app.middleware("http") | ||||
| async def check_url(request: Request, call_next): | ||||
|     start_time = int(time.time()) | ||||
|  | @ -204,6 +135,11 @@ async def check_url(request: Request, call_next): | |||
|     return response | ||||
| 
 | ||||
| 
 | ||||
| @app.on_event("startup") | ||||
| async def on_startup(): | ||||
|     await litellm_app_startup() | ||||
| 
 | ||||
| 
 | ||||
| app.mount("/api/v1", webui_app) | ||||
| app.mount("/litellm/api", litellm_app) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek