forked from open-webui/open-webui
		
	feat: doc tagging
This commit is contained in:
		
							parent
							
								
									8fd1b62e04
								
							
						
					
					
						commit
						00803c92f2
					
				
					 10 changed files with 344 additions and 108 deletions
				
			
		|  | @ -128,6 +128,51 @@ class QueryCollectionsForm(BaseModel): | |||
|     k: Optional[int] = 4 | ||||
| 
 | ||||
| 
 | ||||
def merge_and_sort_query_results(query_results, k):
    """Merge several single-query result sets and keep the ``k`` closest hits.

    Args:
        query_results: Iterable of result dicts. Each must provide the keys
            "ids", "distances", "metadatas" and "documents", each mapping to
            a list whose first element is the list of hits.
        k: Maximum number of hits to keep; ``None`` keeps every hit.

    Returns:
        A dict with the merged hits in ascending distance order under "ids",
        "distances", "metadatas" and "documents" (each wrapped in a
        one-element outer list), and "embeddings", "uris" and "data" set to
        ``None``. When there are no hits at all, every hit list is empty.
    """
    # Keep each hit's fields together so sorting can never misalign them.
    hits = []
    for data in query_results:
        hits.extend(
            zip(
                data["distances"][0],
                data["ids"][0],
                data["metadatas"][0],
                data["documents"][0],
            )
        )

    # Sort on distance only: the sort is stable, so ties keep their input
    # order and the metadata dicts are never compared with each other.
    hits.sort(key=lambda hit: hit[0])
    top_hits = hits[:k]

    # Build the columns directly instead of unpacking zip(*hits), which
    # raises ValueError when no collection produced any hit.
    return {
        "ids": [[hit[1] for hit in top_hits]],
        "distances": [[hit[0] for hit in top_hits]],
        "metadatas": [[hit[2] for hit in top_hits]],
        "documents": [[hit[3] for hit in top_hits]],
        "embeddings": None,
        "uris": None,
        "data": None,
    }
| 
 | ||||
| 
 | ||||
| @app.post("/query/collections") | ||||
| def query_collections( | ||||
|     form_data: QueryCollectionsForm, | ||||
|  | @ -147,7 +192,7 @@ def query_collections( | |||
|         except: | ||||
|             pass | ||||
| 
 | ||||
|     return results | ||||
|     return merge_and_sort_query_results(results, form_data.k) | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/web") | ||||
|  |  | |||
|  | @ -44,6 +44,16 @@ class DocumentModel(BaseModel): | |||
| #################### | ||||
| 
 | ||||
| 
 | ||||
class DocumentResponse(BaseModel):
    # Response model used by the document routes. The handlers build it from
    # a stored document after decoding its JSON `content` string into a dict.
    collection_name: str
    name: str
    title: str
    filename: str
    content: Optional[dict] = None  # decoded JSON content; may be omitted
    user_id: str
    timestamp: int  # epoch timestamp
| 
 | ||||
| 
 | ||||
| class DocumentUpdateForm(BaseModel): | ||||
|     name: str | ||||
|     title: str | ||||
|  | @ -111,6 +121,26 @@ class DocumentsTable: | |||
|             print(e) | ||||
|             return None | ||||
| 
 | ||||
|     def update_doc_content_by_name( | ||||
|         self, name: str, updated: dict | ||||
|     ) -> Optional[DocumentModel]: | ||||
|         try: | ||||
|             doc = self.get_doc_by_name(name) | ||||
|             doc_content = json.loads(doc.content if doc.content else "{}") | ||||
|             doc_content = {**doc_content, **updated} | ||||
| 
 | ||||
|             query = Document.update( | ||||
|                 content=json.dumps(doc_content), | ||||
|                 timestamp=int(time.time()), | ||||
|             ).where(Document.name == name) | ||||
|             query.execute() | ||||
| 
 | ||||
|             doc = Document.get(Document.name == name) | ||||
|             return DocumentModel(**model_to_dict(doc)) | ||||
|         except Exception as e: | ||||
|             print(e) | ||||
|             return None | ||||
| 
 | ||||
|     def delete_doc_by_name(self, name: str) -> bool: | ||||
|         try: | ||||
|             query = Document.delete().where((Document.name == name)) | ||||
|  |  | |||
|  | @ -11,6 +11,7 @@ from apps.web.models.documents import ( | |||
|     DocumentForm, | ||||
|     DocumentUpdateForm, | ||||
|     DocumentModel, | ||||
|     DocumentResponse, | ||||
| ) | ||||
| 
 | ||||
| from utils.utils import get_current_user | ||||
|  | @ -23,9 +24,18 @@ router = APIRouter() | |||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
# List every document. Each stored record's `content` JSON string is decoded
# into a dict (empty or missing content becomes {}) before it is returned.
# The superseded List[DocumentModel] decorator and the early
# `return Documents.get_docs()` are dropped: they registered the route twice
# and made the conversion below unreachable.
@router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user)):
    return [
        DocumentResponse(
            **{
                **doc.model_dump(),
                "content": json.loads(doc.content if doc.content else "{}"),
            }
        )
        for doc in Documents.get_docs()
    ]
| 
 | ||||
| 
 | ||||
| ############################ | ||||
|  | @ -33,7 +43,7 @@ async def get_documents(user=Depends(get_current_user)): | |||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/create", response_model=Optional[DocumentModel]) | ||||
| @router.post("/create", response_model=Optional[DocumentResponse]) | ||||
| async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): | ||||
|     if user.role != "admin": | ||||
|         raise HTTPException( | ||||
|  | @ -46,7 +56,12 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) | |||
|         doc = Documents.insert_new_doc(user.id, form_data) | ||||
| 
 | ||||
|         if doc: | ||||
|             return doc | ||||
|             return DocumentResponse( | ||||
|                 **{ | ||||
|                     **doc.model_dump(), | ||||
|                     "content": json.loads(doc.content if doc.content else "{}"), | ||||
|                 } | ||||
|             ) | ||||
|         else: | ||||
|             raise HTTPException( | ||||
|                 status_code=status.HTTP_400_BAD_REQUEST, | ||||
|  | @ -64,12 +79,45 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) | |||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
# Fetch one document by name and return it with its stored JSON `content`
# decoded into a dict (empty or missing content becomes {}).
# The superseded Optional[DocumentModel] decorator and the bare `return doc`
# are dropped: they registered the route twice and bypassed the conversion.
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
    doc = Documents.get_doc_by_name(name)

    if not doc:
        # NOTE(review): a missing document is reported with 401 rather than
        # 404. Kept unchanged because clients may depend on it; confirm
        # before changing.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    return DocumentResponse(
        **{
            **doc.model_dump(),
            "content": json.loads(doc.content if doc.content else "{}"),
        }
    )
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # TagDocByName | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
class TagDocumentForm(BaseModel):
    # Request body accepted by the document tagging route.
    name: str  # name of the document whose content is updated
    tags: List[dict]  # stored under the "tags" key of the document content
| 
 | ||||
| 
 | ||||
| @router.post("/name/{name}/tags", response_model=Optional[DocumentResponse]) | ||||
| async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): | ||||
|     doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) | ||||
| 
 | ||||
|     if doc: | ||||
|         return DocumentResponse( | ||||
|             **{ | ||||
|                 **doc.model_dump(), | ||||
|                 "content": json.loads(doc.content if doc.content else "{}"), | ||||
|             } | ||||
|         ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_401_UNAUTHORIZED, | ||||
|  | @ -82,7 +130,7 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)): | |||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.post("/name/{name}/update", response_model=Optional[DocumentModel]) | ||||
| @router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) | ||||
| async def update_doc_by_name( | ||||
|     name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) | ||||
| ): | ||||
|  | @ -94,7 +142,12 @@ async def update_doc_by_name( | |||
| 
 | ||||
|     doc = Documents.update_doc_by_name(name, form_data) | ||||
|     if doc: | ||||
|         return doc | ||||
|         return DocumentResponse( | ||||
|             **{ | ||||
|                 **doc.model_dump(), | ||||
|                 "content": json.loads(doc.content if doc.content else "{}"), | ||||
|             } | ||||
|         ) | ||||
|     else: | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek