From 00803c92f2e4c53e8e8e827252d93ea26cd47a75 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 3 Feb 2024 14:44:49 -0800 Subject: [PATCH] feat: doc tagging --- backend/apps/rag/main.py | 47 ++++++++- backend/apps/web/models/documents.py | 30 ++++++ backend/apps/web/routers/documents.py | 69 +++++++++++-- src/lib/apis/documents/index.ts | 41 ++++++++ src/lib/components/common/Tags.svelte | 24 +++++ .../components/common/Tags/TagInput.svelte | 64 ++++++++++++ src/lib/components/common/Tags/TagList.svelte | 33 +++++++ .../components/documents/EditDocModal.svelte | 44 ++++++++- src/lib/components/layout/Navbar.svelte | 99 +------------------ src/routes/(app)/documents/+page.svelte | 1 - 10 files changed, 344 insertions(+), 108 deletions(-) create mode 100644 src/lib/components/common/Tags.svelte create mode 100644 src/lib/components/common/Tags/TagInput.svelte create mode 100644 src/lib/components/common/Tags/TagList.svelte diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index eec3dfa2..de00a581 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -128,6 +128,51 @@ class QueryCollectionsForm(BaseModel): k: Optional[int] = 4 +def merge_and_sort_query_results(query_results, k): + # Initialize lists to store combined data + combined_ids = [] + combined_distances = [] + combined_metadatas = [] + combined_documents = [] + + # Combine data from each dictionary + for data in query_results: + combined_ids.extend(data["ids"][0]) + combined_distances.extend(data["distances"][0]) + combined_metadatas.extend(data["metadatas"][0]) + combined_documents.extend(data["documents"][0]) + + # Create a list of tuples (distance, id, metadata, document) + combined = list( + zip(combined_distances, combined_ids, combined_metadatas, combined_documents) + ) + + # Sort the list based on distances + combined.sort(key=lambda x: x[0]) + + # Unzip the sorted list + sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) + + # Slicing the lists to include only k elements + sorted_distances = list(sorted_distances)[:k] + sorted_ids = list(sorted_ids)[:k] + sorted_metadatas = list(sorted_metadatas)[:k] + sorted_documents = list(sorted_documents)[:k] + + # Create the output dictionary + merged_query_results = { + "ids": [sorted_ids], + "distances": [sorted_distances], + "metadatas": [sorted_metadatas], + "documents": [sorted_documents], + "embeddings": None, + "uris": None, + "data": None, + } + + return merged_query_results + + @app.post("/query/collections") def query_collections( form_data: QueryCollectionsForm, @@ -147,7 +192,7 @@ def query_collections( except: pass - return results + return merge_and_sort_query_results(results, form_data.k) @app.post("/web") diff --git a/backend/apps/web/models/documents.py b/backend/apps/web/models/documents.py index 0196c38b..6a372b2c 100644 --- a/backend/apps/web/models/documents.py +++ b/backend/apps/web/models/documents.py @@ -44,6 +44,16 @@ class DocumentModel(BaseModel): #################### +class DocumentResponse(BaseModel): + collection_name: str + name: str + title: str + filename: str + content: Optional[dict] = None + user_id: str + timestamp: int # timestamp in epoch + + class DocumentUpdateForm(BaseModel): name: str title: str @@ -111,6 +121,26 @@ class DocumentsTable: print(e) return None + def update_doc_content_by_name( + self, name: str, updated: dict + ) -> Optional[DocumentModel]: + try: + doc = self.get_doc_by_name(name) + doc_content = json.loads(doc.content if doc.content else "{}") + doc_content = {**doc_content, **updated} + + query = Document.update( + content=json.dumps(doc_content), + timestamp=int(time.time()), + ).where(Document.name == name) + query.execute() + + doc = Document.get(Document.name == name) + return DocumentModel(**model_to_dict(doc)) + except Exception as e: + print(e) + return None + def delete_doc_by_name(self, name: str) -> bool: try: query = Document.delete().where((Document.name == name)) diff --git a/backend/apps/web/routers/documents.py b/backend/apps/web/routers/documents.py index c64ee8f0..3b6434d1 100644 --- a/backend/apps/web/routers/documents.py +++ b/backend/apps/web/routers/documents.py @@ -11,6 +11,7 @@ from apps.web.models.documents import ( DocumentForm, DocumentUpdateForm, DocumentModel, + DocumentResponse, ) from utils.utils import get_current_user @@ -23,9 +24,18 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[DocumentModel]) +@router.get("/", response_model=List[DocumentResponse]) async def get_documents(user=Depends(get_current_user)): - return Documents.get_docs() + docs = [ + DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) + for doc in Documents.get_docs() + ] + return docs ############################ @@ -33,7 +43,7 @@ async def get_documents(user=Depends(get_current_user)): ############################ -@router.post("/create", response_model=Optional[DocumentModel]) +@router.post("/create", response_model=Optional[DocumentResponse]) async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): if user.role != "admin": raise HTTPException( @@ -46,7 +56,12 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) doc = Documents.insert_new_doc(user.id, form_data) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -64,12 +79,45 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) ############################ -@router.get("/name/{name}", response_model=Optional[DocumentModel]) +@router.get("/name/{name}", response_model=Optional[DocumentResponse]) async def get_doc_by_name(name: str, user=Depends(get_current_user)): doc = Documents.get_doc_by_name(name) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# TagDocByName +############################ + + +class TagDocumentForm(BaseModel): + name: str + tags: List[dict] + + +@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse]) +async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): + doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) + + if doc: + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -82,7 +130,7 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)): ############################ -@router.post("/name/{name}/update", response_model=Optional[DocumentModel]) +@router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) async def update_doc_by_name( name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) ): @@ -94,7 +142,12 @@ async def update_doc_by_name( doc = Documents.update_doc_by_name(name, form_data) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/lib/apis/documents/index.ts b/src/lib/apis/documents/index.ts index fb208ea4..2f7fb2b9 100644 --- a/src/lib/apis/documents/index.ts +++ b/src/lib/apis/documents/index.ts @@ -144,6 +144,47 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda return res; }; +type TagDocForm = { + name: string; + tags: string[]; +}; + +export const tagDocByName = async (token: string, name: string, form: TagDocForm) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/documents/name/${name}/tags`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + name: form.name, + tags: form.tags + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteDocByName = async (token: string, name: string) => { let error = null; diff --git a/src/lib/components/common/Tags.svelte b/src/lib/components/common/Tags.svelte new file mode 100644 index 00000000..426678fb --- /dev/null +++ b/src/lib/components/common/Tags.svelte @@ -0,0 +1,24 @@ + + +
+ { + deleteTag(e.detail); + }} + /> + + { + addTag(e.detail); + }} + /> +
diff --git a/src/lib/components/common/Tags/TagInput.svelte b/src/lib/components/common/Tags/TagInput.svelte new file mode 100644 index 00000000..64726cf8 --- /dev/null +++ b/src/lib/components/common/Tags/TagInput.svelte @@ -0,0 +1,64 @@ + + +
+ {#if showTagInput} +
+ + + +
+ + + {/if} + + +
diff --git a/src/lib/components/common/Tags/TagList.svelte b/src/lib/components/common/Tags/TagList.svelte new file mode 100644 index 00000000..ea918ca5 --- /dev/null +++ b/src/lib/components/common/Tags/TagList.svelte @@ -0,0 +1,33 @@ + + +{#each tags as tag} +
+
+ {tag.name} +
+ +
+{/each} diff --git a/src/lib/components/documents/EditDocModal.svelte b/src/lib/components/documents/EditDocModal.svelte index 35ea1539..43b90e02 100644 --- a/src/lib/components/documents/EditDocModal.svelte +++ b/src/lib/components/documents/EditDocModal.svelte @@ -3,16 +3,22 @@ import dayjs from 'dayjs'; import { onMount } from 'svelte'; - import { getDocs, updateDocByName } from '$lib/apis/documents'; + import { getDocs, tagDocByName, updateDocByName } from '$lib/apis/documents'; import Modal from '../common/Modal.svelte'; import { documents } from '$lib/stores'; + import TagInput from '../common/Tags/TagInput.svelte'; + import Tags from '../common/Tags.svelte'; + import { addTagById } from '$lib/apis/chats'; export let show = false; export let selectedDoc; + let tags = []; + let doc = { name: '', - title: '' + title: '', + content: null }; const submitHandler = async () => { @@ -30,9 +36,37 @@ } }; + const addTagHandler = async (tagName) => { + if (!tags.find((tag) => tag.name === tagName)) { + tags = [...tags, { name: tagName }]; + + await tagDocByName(localStorage.token, doc.name, { + name: doc.name, + tags: tags + }); + + documents.set(await getDocs(localStorage.token)); + } else { + console.log('tag already exists'); + } + }; + + const deleteTagHandler = async (tagName) => { + tags = tags.filter((tag) => tag.name !== tagName); + + await tagDocByName(localStorage.token, doc.name, { + name: doc.name, + tags: tags + }); + + documents.set(await getDocs(localStorage.token)); + }; + onMount(() => { if (selectedDoc) { doc = JSON.parse(JSON.stringify(selectedDoc)); + + tags = doc?.content?.tags ?? []; } }); @@ -112,6 +146,12 @@ /> + +
+
Tags
+ + +
diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index 3fc0ffa0..e829aeb3 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -6,6 +6,8 @@ import { getChatById } from '$lib/apis/chats'; import { chatId, modelfiles } from '$lib/stores'; import ShareChatModal from '../chat/ShareChatModal.svelte'; + import TagInput from '../common/Tags/TagInput.svelte'; + import Tags from '../common/Tags.svelte'; export let initNewChat: Function; export let title: string = 'Ollama Web UI'; @@ -61,21 +63,6 @@ saveAs(blob, `chat-${chat.title}.txt`); }; - - const addTagHandler = () => { - // if (!tags.find((e) => e.name === tagName)) { - // tags = [ - // ...tags, - // { - // name: JSON.parse(JSON.stringify(tagName)) - // } - // ]; - // } - - addTag(tagName); - tagName = ''; - showTagInput = false; - }; @@ -116,87 +103,7 @@
{#if shareEnabled} -
- {#each tags as tag} -
-
- {tag.name} -
- -
- {/each} - -
- {#if showTagInput} -
- - - -
- - - {/if} - - -
-
+