From 485236624f25252f895a4f2a799f2e21129429e4 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 1 Feb 2024 13:17:47 -0800 Subject: [PATCH 01/12] feat: drag and drop document anywhere --- src/lib/components/AddFilesPlaceholder.svelte | 8 +- src/routes/(app)/documents/+page.svelte | 282 ++++++++++-------- 2 files changed, 167 insertions(+), 123 deletions(-) diff --git a/src/lib/components/AddFilesPlaceholder.svelte b/src/lib/components/AddFilesPlaceholder.svelte index 7cc51c6f..3cbd0454 100644 --- a/src/lib/components/AddFilesPlaceholder.svelte +++ b/src/lib/components/AddFilesPlaceholder.svelte @@ -1,6 +1,8 @@
📄
Add Files
-
- Drop any files here to add to the conversation -
+
+ Drop any files here to add to the conversation +
+
diff --git a/src/routes/(app)/documents/+page.svelte b/src/routes/(app)/documents/+page.svelte index f5b8e5ec..2c4d82fe 100644 --- a/src/routes/(app)/documents/+page.svelte +++ b/src/routes/(app)/documents/+page.svelte @@ -12,6 +12,7 @@ import { transformFileName } from '$lib/utils'; import EditDocModal from '$lib/components/documents/EditDocModal.svelte'; + import AddFilesPlaceholder from '$lib/components/AddFilesPlaceholder.svelte'; let importFiles = ''; @@ -49,44 +50,94 @@ } }; - const onDragOver = (e) => { - e.preventDefault(); - dragged = true; - }; + onMount(() => { + const dropZone = document.querySelector('body'); - const onDragLeave = () => { - dragged = false; - }; + const onDragOver = (e) => { + e.preventDefault(); + dragged = true; + }; - const onDrop = async (e) => { - e.preventDefault(); - console.log(e); + const onDragLeave = () => { + dragged = false; + }; - if (e.dataTransfer?.files) { - const inputFiles = e.dataTransfer?.files; + const onDrop = async (e) => { + e.preventDefault(); + console.log(e); - if (inputFiles && inputFiles.length > 0) { - const file = inputFiles[0]; - if ( - SUPPORTED_FILE_TYPE.includes(file['type']) || - SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1)) - ) { - uploadDoc(file); + if (e.dataTransfer?.files) { + let reader = new FileReader(); + + reader.onload = (event) => { + files = [ + ...files, + { + type: 'image', + url: `${event.target.result}` + } + ]; + }; + + const inputFiles = e.dataTransfer?.files; + + if (inputFiles && inputFiles.length > 0) { + const file = inputFiles[0]; + console.log(file, file.name.split('.').at(-1)); + if (['image/gif', 'image/jpeg', 'image/png'].includes(file['type'])) { + reader.readAsDataURL(file); + } else if ( + SUPPORTED_FILE_TYPE.includes(file['type']) || + SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1)) + ) { + uploadDoc(file); + } else { + toast.error( + `Unknown File Type '${file['type']}', but accepting and treating as plain text` + ); + uploadDoc(file); + } } else { - toast.error( - `Unknown File Type '${file['type']}', but accepting and treating as plain text` - ); - uploadDoc(file); + toast.error(`File not found.`); } - } else { - toast.error(`File not found.`); } - } - dragged = false; - }; + dragged = false; + }; + + dropZone?.addEventListener('dragover', onDragOver); + dropZone?.addEventListener('drop', onDrop); + dropZone?.addEventListener('dragleave', onDragLeave); + + return () => { + dropZone?.removeEventListener('dragover', onDragOver); + dropZone?.removeEventListener('drop', onDrop); + dropZone?.removeEventListener('dragleave', onDragLeave); + }; + }); +{#if dragged} +
+
+
+
+ +
+ Drop any files here to add to my documents +
+
+
+
+
+
+{/if} + {#key selectedDoc} {/key} @@ -170,7 +221,7 @@ }} /> -
+ {#each $documents.filter((p) => query === '' || p.name.includes(query)) as doc}
@@ -330,106 +381,97 @@
{/each} - {#if $documents.length != 0} -
-
-
- { - console.log(importFiles); +
- const reader = new FileReader(); - reader.onload = async (event) => { - const savedDocs = JSON.parse(event.target.result); - console.log(savedDocs); +
+
+ { + console.log(importFiles); - for (const doc of savedDocs) { - await createNewDoc( - localStorage.token, - doc.collection_name, - doc.filename, - doc.name, - doc.title - ).catch((error) => { - toast.error(error); - return null; - }); - } + const reader = new FileReader(); + reader.onload = async (event) => { + const savedDocs = JSON.parse(event.target.result); + console.log(savedDocs); - await documents.set(await getDocs(localStorage.token)); - }; + for (const doc of savedDocs) { + await createNewDoc( + localStorage.token, + doc.collection_name, + doc.filename, + doc.name, + doc.title + ).catch((error) => { + toast.error(error); + return null; + }); + } - reader.readAsText(importFiles[0]); - }} - /> + await documents.set(await getDocs(localStorage.token)); + }; - - - - - -
+
Import Documents Mapping
+ +
+ + + +
+ + +
- {/if} +
From 50f7b20ac293b2f17de0f382bbcd7ae5d0f89349 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Thu, 1 Feb 2024 13:35:41 -0800 Subject: [PATCH 02/12] refac --- backend/apps/rag/main.py | 46 +++++++++++++++++++++++----- src/lib/apis/rag/index.ts | 31 +++++++++---------- src/routes/(app)/+page.svelte | 36 ++++++++++++---------- src/routes/(app)/c/[id]/+page.svelte | 36 ++++++++++++---------- 4 files changed, 91 insertions(+), 58 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 85bc995a..eec3dfa2 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -10,6 +10,7 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware import os, shutil +from typing import List # from chromadb.utils import embedding_functions @@ -96,19 +97,22 @@ async def get_status(): return {"status": True} -@app.get("/query/{collection_name}") +class QueryCollectionForm(BaseModel): + collection_name: str + query: str + k: Optional[int] = 4 + + +@app.post("/query/collection") def query_collection( - collection_name: str, - query: str, - k: Optional[int] = 4, + form_data: QueryCollectionForm, user=Depends(get_current_user), ): try: collection = CHROMA_CLIENT.get_collection( - name=collection_name, + name=form_data.collection_name, ) - result = collection.query(query_texts=[query], n_results=k) - + result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result except Exception as e: print(e) @@ -118,6 +122,34 @@ def query_collection( ) +class QueryCollectionsForm(BaseModel): + collection_names: List[str] + query: str + k: Optional[int] = 4 + + +@app.post("/query/collections") +def query_collections( + form_data: QueryCollectionsForm, + user=Depends(get_current_user), +): + results = [] + + for collection_name in form_data.collection_names: + try: + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query( + query_texts=[form_data.query], n_results=form_data.k + ) + results.append(result) + except: + pass + + return results + + @app.post("/web") def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 8a2b8cb4..e2656943 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -66,28 +66,25 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string export const queryVectorDB = async ( token: string, - collection_name: string, + collection_names: string[], query: string, k: number ) => { let error = null; - const searchParams = new URLSearchParams(); - searchParams.set('query', query); - if (k) { - searchParams.set('k', k.toString()); - } - - const res = await fetch( - `${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`, - { - method: 'GET', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - } - ) + const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_names: collection_names, + query: query, + k: k + }) + }) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index eebf7743..657c9d9f 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -232,26 +232,28 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await Promise.all( - docs.map(async (doc) => { - return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); - }) - ); - relevantContexts = relevantContexts.filter((context) => context); + let relevantContexts = await queryVectorDB( + localStorage.token, + docs.map((d) => d.collection_name), + query, + 4 + ).catch((error) => { + console.log(error); + return null; + }); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + if (relevantContexts) { + relevantContexts = relevantContexts.filter((context) => context); - console.log(contextString); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; + console.log(contextString); + + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; + } await tick(); processing = ''; } diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index c161435d..5509019a 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -246,26 +246,28 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await Promise.all( - docs.map(async (doc) => { - return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); - }) - ); - relevantContexts = relevantContexts.filter((context) => context); + let relevantContexts = await queryVectorDB( + localStorage.token, + docs.map((d) => d.collection_name), + query, + 4 + ).catch((error) => { + console.log(error); + return null; + }); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + if (relevantContexts) { + relevantContexts = relevantContexts.filter((context) => context); - console.log(contextString); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; + console.log(contextString); + + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; + } await tick(); processing = ''; } From 1d0eaec37e490bac793f6ca089af44ed736403ba Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 2 Feb 2024 22:57:18 -0800 Subject: [PATCH 03/12] refac: queryVectorDB renamed to queryCollection --- src/lib/apis/rag/index.ts | 6 ++--- src/routes/(app)/+page.svelte | 38 +++++++++++++--------------- src/routes/(app)/c/[id]/+page.svelte | 38 +++++++++++++--------------- 3 files changed, 39 insertions(+), 43 deletions(-) diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index e2656943..08dff5bf 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -66,13 +66,13 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string export const queryVectorDB = async ( token: string, - collection_names: string[], + collection_name: string, query: string, k: number ) => { let error = null; - const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, { + const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { method: 'POST', headers: { Accept: 'application/json', @@ -80,7 +80,7 @@ export const queryVectorDB = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - collection_names: collection_names, + collection_name: collection_name, query: query, k: k }) diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 657c9d9f..956b6cb0 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -28,7 +28,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryVectorDB } from '$lib/apis/rag'; + import { queryCollection } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -232,28 +232,26 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await queryVectorDB( - localStorage.token, - docs.map((d) => d.collection_name), - query, - 4 - ).catch((error) => { - console.log(error); - return null; - }); + let relevantContexts = await Promise.all( + docs.map(async (doc) => { + return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + }) + ); + relevantContexts = relevantContexts.filter((context) => context); - if (relevantContexts) { - relevantContexts = relevantContexts.filter((context) => context); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + console.log(contextString); - console.log(contextString); - - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; - } + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; await tick(); processing = ''; } diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index 5509019a..fac8a01c 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -29,7 +29,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryVectorDB } from '$lib/apis/rag'; + import { queryCollection } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -246,28 +246,26 @@ processing = 'Reading'; const query = history.messages[parentId].content; - let relevantContexts = await queryVectorDB( - localStorage.token, - docs.map((d) => d.collection_name), - query, - 4 - ).catch((error) => { - console.log(error); - return null; - }); + let relevantContexts = await Promise.all( + docs.map(async (doc) => { + return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + }) + ); + relevantContexts = relevantContexts.filter((context) => context); - if (relevantContexts) { - relevantContexts = relevantContexts.filter((context) => context); + const contextString = relevantContexts.reduce((a, context, i, arr) => { + return `${a}${context.documents.join(' ')}\n`; + }, ''); - const contextString = relevantContexts.reduce((a, context, i, arr) => { - return `${a}${context.documents.join(' ')}\n`; - }, ''); + console.log(contextString); - console.log(contextString); - - history.messages[parentId].raContent = RAGTemplate(contextString, query); - history.messages[parentId].contexts = relevantContexts; - } + history.messages[parentId].raContent = RAGTemplate(contextString, query); + history.messages[parentId].contexts = relevantContexts; await tick(); processing = ''; } From 8fd1b62e04a09f21e56d1e07e46b145067871a84 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 2 Feb 2024 22:59:36 -0800 Subject: [PATCH 04/12] fix: api function name --- src/lib/apis/rag/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 08dff5bf..ca14371f 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -64,7 +64,7 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string return res; }; -export const queryVectorDB = async ( +export const queryCollection = async ( token: string, collection_name: string, query: string, From 00803c92f2e4c53e8e8e827252d93ea26cd47a75 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 3 Feb 2024 14:44:49 -0800 Subject: [PATCH 05/12] feat: doc tagging --- backend/apps/rag/main.py | 47 ++++++++- backend/apps/web/models/documents.py | 30 ++++++ backend/apps/web/routers/documents.py | 69 +++++++++++-- src/lib/apis/documents/index.ts | 41 ++++++++ src/lib/components/common/Tags.svelte | 24 +++++ .../components/common/Tags/TagInput.svelte | 64 ++++++++++++ src/lib/components/common/Tags/TagList.svelte | 33 +++++++ .../components/documents/EditDocModal.svelte | 44 ++++++++- src/lib/components/layout/Navbar.svelte | 99 +------------------ src/routes/(app)/documents/+page.svelte | 1 - 10 files changed, 344 insertions(+), 108 deletions(-) create mode 100644 src/lib/components/common/Tags.svelte create mode 100644 src/lib/components/common/Tags/TagInput.svelte create mode 100644 src/lib/components/common/Tags/TagList.svelte diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index eec3dfa2..de00a581 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -128,6 +128,51 @@ class QueryCollectionsForm(BaseModel): k: Optional[int] = 4 +def merge_and_sort_query_results(query_results, k): + # Initialize lists to store combined data + combined_ids = [] + combined_distances = [] + combined_metadatas = [] + combined_documents = [] + + # Combine data from each dictionary + for data in query_results: + combined_ids.extend(data["ids"][0]) + combined_distances.extend(data["distances"][0]) + combined_metadatas.extend(data["metadatas"][0]) + combined_documents.extend(data["documents"][0]) + + # Create a list of tuples (distance, id, metadata, document) + combined = list( + zip(combined_distances, combined_ids, combined_metadatas, combined_documents) + ) + + # Sort the list based on distances + combined.sort(key=lambda x: x[0]) + + # Unzip the sorted list + sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) + + # Slicing the lists to include only k elements + sorted_distances = list(sorted_distances)[:k] + sorted_ids = list(sorted_ids)[:k] + sorted_metadatas = list(sorted_metadatas)[:k] + sorted_documents = list(sorted_documents)[:k] + + # Create the output dictionary + merged_query_results = { + "ids": [sorted_ids], + "distances": [sorted_distances], + "metadatas": [sorted_metadatas], + "documents": [sorted_documents], + "embeddings": None, + "uris": None, + "data": None, + } + + return merged_query_results + + @app.post("/query/collections") def query_collections( form_data: QueryCollectionsForm, @@ -147,7 +192,7 @@ def query_collections( except: pass - return results + return merge_and_sort_query_results(results, form_data.k) @app.post("/web") diff --git a/backend/apps/web/models/documents.py b/backend/apps/web/models/documents.py index 0196c38b..6a372b2c 100644 --- a/backend/apps/web/models/documents.py +++ b/backend/apps/web/models/documents.py @@ -44,6 +44,16 @@ class DocumentModel(BaseModel): #################### +class DocumentResponse(BaseModel): + collection_name: str + name: str + title: str + filename: str + content: Optional[dict] = None + user_id: str + timestamp: int # timestamp in epoch + + class DocumentUpdateForm(BaseModel): name: str title: str @@ -111,6 +121,26 @@ class DocumentsTable: print(e) return None + def update_doc_content_by_name( + self, name: str, updated: dict + ) -> Optional[DocumentModel]: + try: + doc = self.get_doc_by_name(name) + doc_content = json.loads(doc.content if doc.content else "{}") + doc_content = {**doc_content, **updated} + + query = Document.update( + content=json.dumps(doc_content), + timestamp=int(time.time()), + ).where(Document.name == name) + query.execute() + + doc = Document.get(Document.name == name) + return DocumentModel(**model_to_dict(doc)) + except Exception as e: + print(e) + return None + def delete_doc_by_name(self, name: str) -> bool: try: query = Document.delete().where((Document.name == name)) diff --git a/backend/apps/web/routers/documents.py b/backend/apps/web/routers/documents.py index c64ee8f0..3b6434d1 100644 --- a/backend/apps/web/routers/documents.py +++ b/backend/apps/web/routers/documents.py @@ -11,6 +11,7 @@ from apps.web.models.documents import ( DocumentForm, DocumentUpdateForm, DocumentModel, + DocumentResponse, ) from utils.utils import get_current_user @@ -23,9 +24,18 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[DocumentModel]) +@router.get("/", response_model=List[DocumentResponse]) async def get_documents(user=Depends(get_current_user)): - return Documents.get_docs() + docs = [ + DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) + for doc in Documents.get_docs() + ] + return docs ############################ @@ -33,7 +43,7 @@ async def get_documents(user=Depends(get_current_user)): ############################ -@router.post("/create", response_model=Optional[DocumentModel]) +@router.post("/create", response_model=Optional[DocumentResponse]) async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): if user.role != "admin": raise HTTPException( @@ -46,7 +56,12 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) doc = Documents.insert_new_doc(user.id, form_data) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -64,12 +79,45 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) ############################ -@router.get("/name/{name}", response_model=Optional[DocumentModel]) +@router.get("/name/{name}", response_model=Optional[DocumentResponse]) async def get_doc_by_name(name: str, user=Depends(get_current_user)): doc = Documents.get_doc_by_name(name) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# TagDocByName +############################ + + +class TagDocumentForm(BaseModel): + name: str + tags: List[dict] + + +@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse]) +async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): + doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) + + if doc: + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -82,7 +130,7 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)): ############################ -@router.post("/name/{name}/update", response_model=Optional[DocumentModel]) +@router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) async def update_doc_by_name( name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) ): @@ -94,7 +142,12 @@ async def update_doc_by_name( doc = Documents.update_doc_by_name(name, form_data) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/lib/apis/documents/index.ts b/src/lib/apis/documents/index.ts index fb208ea4..2f7fb2b9 100644 --- a/src/lib/apis/documents/index.ts +++ b/src/lib/apis/documents/index.ts @@ -144,6 +144,47 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda return res; }; +type TagDocForm = { + name: string; + tags: string[]; +}; + +export const tagDocByName = async (token: string, name: string, form: TagDocForm) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/documents/name/${name}/tags`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + name: form.name, + tags: form.tags + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteDocByName = async (token: string, name: string) => { let error = null; diff --git a/src/lib/components/common/Tags.svelte b/src/lib/components/common/Tags.svelte new file mode 100644 index 00000000..426678fb --- /dev/null +++ b/src/lib/components/common/Tags.svelte @@ -0,0 +1,24 @@ + + +
+ { + deleteTag(e.detail); + }} + /> + + { + addTag(e.detail); + }} + /> +
diff --git a/src/lib/components/common/Tags/TagInput.svelte b/src/lib/components/common/Tags/TagInput.svelte new file mode 100644 index 00000000..64726cf8 --- /dev/null +++ b/src/lib/components/common/Tags/TagInput.svelte @@ -0,0 +1,64 @@ + + +
+ {#if showTagInput} +
+ + + +
+ + + {/if} + + +
diff --git a/src/lib/components/common/Tags/TagList.svelte b/src/lib/components/common/Tags/TagList.svelte new file mode 100644 index 00000000..ea918ca5 --- /dev/null +++ b/src/lib/components/common/Tags/TagList.svelte @@ -0,0 +1,33 @@ + + +{#each tags as tag} +
+
+ {tag.name} +
+ +
+{/each} diff --git a/src/lib/components/documents/EditDocModal.svelte b/src/lib/components/documents/EditDocModal.svelte index 35ea1539..43b90e02 100644 --- a/src/lib/components/documents/EditDocModal.svelte +++ b/src/lib/components/documents/EditDocModal.svelte @@ -3,16 +3,22 @@ import dayjs from 'dayjs'; import { onMount } from 'svelte'; - import { getDocs, updateDocByName } from '$lib/apis/documents'; + import { getDocs, tagDocByName, updateDocByName } from '$lib/apis/documents'; import Modal from '../common/Modal.svelte'; import { documents } from '$lib/stores'; + import TagInput from '../common/Tags/TagInput.svelte'; + import Tags from '../common/Tags.svelte'; + import { addTagById } from '$lib/apis/chats'; export let show = false; export let selectedDoc; + let tags = []; + let doc = { name: '', - title: '' + title: '', + content: null }; const submitHandler = async () => { @@ -30,9 +36,37 @@ } }; + const addTagHandler = async (tagName) => { + if (!tags.find((tag) => tag.name === tagName)) { + tags = [...tags, { name: tagName }]; + + await tagDocByName(localStorage.token, doc.name, { + name: doc.name, + tags: tags + }); + + documents.set(await getDocs(localStorage.token)); + } else { + console.log('tag already exists'); + } + }; + + const deleteTagHandler = async (tagName) => { + tags = tags.filter((tag) => tag.name !== tagName); + + await tagDocByName(localStorage.token, doc.name, { + name: doc.name, + tags: tags + }); + + documents.set(await getDocs(localStorage.token)); + }; + onMount(() => { if (selectedDoc) { doc = JSON.parse(JSON.stringify(selectedDoc)); + + tags = doc?.content?.tags ?? []; } }); @@ -112,6 +146,12 @@ />
+ +
+
Tags
+ + +
diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index 3fc0ffa0..e829aeb3 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -6,6 +6,8 @@ import { getChatById } from '$lib/apis/chats'; import { chatId, modelfiles } from '$lib/stores'; import ShareChatModal from '../chat/ShareChatModal.svelte'; + import TagInput from '../common/Tags/TagInput.svelte'; + import Tags from '../common/Tags.svelte'; export let initNewChat: Function; export let title: string = 'Ollama Web UI'; @@ -61,21 +63,6 @@ saveAs(blob, `chat-${chat.title}.txt`); }; - - const addTagHandler = () => { - // if (!tags.find((e) => e.name === tagName)) { - // tags = [ - // ...tags, - // { - // name: JSON.parse(JSON.stringify(tagName)) - // } - // ]; - // } - - addTag(tagName); - tagName = ''; - showTagInput = false; - }; @@ -116,87 +103,7 @@
{#if shareEnabled} -
- {#each tags as tag} -
-
- {tag.name} -
- -
- {/each} - -
- {#if showTagInput} -
- - - -
- - - {/if} - - -
-
+
+
- { - if (inputFiles && inputFiles.length > 0) { - const file = inputFiles[0]; - if ( - SUPPORTED_FILE_TYPE.includes(file['type']) || - SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1)) - ) { - uploadDoc(file); - } else { - toast.error( - `Unknown File Type '${file['type']}', but accepting and treating as plain text` - ); - uploadDoc(file); - } - - inputFiles = null; - e.target.value = ''; - } else { - toast.error(`File not found.`); - } - }} - /> + {#if tags.length > 0} +
+ + {#each tags as tag} + + {/each} +
+ {/if} {#each $documents.filter((p) => query === '' || p.name.includes(query)) as doc} -
-
+
From f448a4b385af9020b9c070fecd9eae6cd440f631 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 3 Feb 2024 15:17:00 -0800 Subject: [PATCH 07/12] feat: doc filter by tag --- src/routes/(app)/documents/+page.svelte | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/routes/(app)/documents/+page.svelte b/src/routes/(app)/documents/+page.svelte index 898c5f5f..68ca044d 100644 --- a/src/routes/(app)/documents/+page.svelte +++ b/src/routes/(app)/documents/+page.svelte @@ -22,6 +22,7 @@ let showEditDocModal = false; let selectedDoc; + let selectedTag = ''; let dragged = false; @@ -233,6 +234,7 @@
--> - {#each $documents.filter((p) => query === '' || p.name.includes(query)) as doc} + {#each $documents.filter((doc) => (selectedTag === '' || (doc?.content?.tags ?? []) + .map((tag) => tag.name) + .includes(selectedTag)) && (query === '' || doc.name.includes(query))) as doc}
From 7d2f788a3b2f620a0f99bfe60fe011e060ffb14f Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 3 Feb 2024 15:48:44 -0800 Subject: [PATCH 08/12] feat: import collection from chat input --- src/lib/components/chat/MessageInput.svelte | 30 +++++++++++- .../chat/MessageInput/Documents.svelte | 48 +++++++++++++++---- 2 files changed, 69 insertions(+), 9 deletions(-) diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 2aec58d4..3ba51a6f 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -295,7 +295,7 @@ files = [ ...files, { - type: 'doc', + type: e?.detail?.type ?? 'doc', ...e.detail, upload_status: true } @@ -446,6 +446,34 @@
Document
+ {:else if file.type === 'collection'} +
+
+ + + + +
+ +
+
+ #{file.name} +
+ +
Collection
+
+
{/if}
diff --git a/src/lib/components/chat/MessageInput/Documents.svelte b/src/lib/components/chat/MessageInput/Documents.svelte index 5f252b3d..6cc7bf4d 100644 --- a/src/lib/components/chat/MessageInput/Documents.svelte +++ b/src/lib/components/chat/MessageInput/Documents.svelte @@ -10,12 +10,35 @@ const dispatch = createEventDispatcher(); let selectedIdx = 0; + + let filteredItems = []; let filteredDocs = []; + let filteredTags = []; + + let collections = []; + + $: collections = $documents + .reduce((a, e, i, arr) => { + return [...new Set([...a, ...(e?.content?.tags ?? []).map((tag) => tag.name)])]; + }, []) + .map((tag) => ({ + name: tag, + type: 'collection', + collection_names: $documents + .filter((doc) => (doc?.content?.tags ?? []).map((tag) => tag.name).includes(tag)) + .map((doc) => doc.collection_name) + })); + + $: filteredCollections = collections + .filter((tag) => tag.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? '')) + .sort((a, b) => a.name.localeCompare(b.name)); $: filteredDocs = $documents .filter((p) => p.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? '')) .sort((a, b) => a.title.localeCompare(b.title)); + $: filteredItems = [...filteredCollections, ...filteredDocs]; + $: if (prompt) { selectedIdx = 0; } @@ -25,7 +48,7 @@ }; export const selectDown = () => { - selectedIdx = Math.min(selectedIdx + 1, filteredDocs.length - 1); + selectedIdx = Math.min(selectedIdx + 1, filteredItems.length - 1); }; const confirmSelect = async (doc) => { @@ -60,7 +83,7 @@
- {#each filteredDocs as doc, docIdx} + {#each filteredItems as doc, docIdx} {/each} From 683650ec00f6f619de4fb31ba687da783de499ec Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 3 Feb 2024 15:57:06 -0800 Subject: [PATCH 09/12] feat: collection rag integration --- backend/apps/rag/main.py | 12 +++--- src/lib/apis/rag/index.ts | 41 ++++++++++++++++++- .../chat/Messages/UserMessage.svelte | 29 +++++++++++++ src/routes/(app)/+page.svelte | 27 ++++++++---- src/routes/(app)/c/[id]/+page.svelte | 27 ++++++++---- 5 files changed, 112 insertions(+), 24 deletions(-) diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index de00a581..95535274 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -97,15 +97,15 @@ async def get_status(): return {"status": True} -class QueryCollectionForm(BaseModel): +class QueryDocForm(BaseModel): collection_name: str query: str k: Optional[int] = 4 -@app.post("/query/collection") -def query_collection( - form_data: QueryCollectionForm, +@app.post("/query/doc") +def query_doc( + form_data: QueryDocForm, user=Depends(get_current_user), ): try: @@ -173,8 +173,8 @@ def merge_and_sort_query_results(query_results, k): return merged_query_results -@app.post("/query/collections") -def query_collections( +@app.post("/query/collection") +def query_collection( form_data: QueryCollectionsForm, user=Depends(get_current_user), ): diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index ca14371f..3f4f30bf 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -64,7 +64,7 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string return res; }; -export const queryCollection = async ( +export const queryDoc = async ( token: string, collection_name: string, query: string, @@ -72,6 +72,43 @@ export const queryCollection = async ( ) => { let error = null; + const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_name: collection_name, + query: query, + k: k + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const queryCollection = async ( + token: string, + collection_names: string, + query: string, + k: number +) => { + let error = null; + const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { method: 'POST', headers: { @@ -80,7 +117,7 @@ export const queryCollection = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - collection_name: collection_name, + collection_names: collection_names, query: query, k: k }) diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte index 761ba41c..0e0fc332 100644 --- a/src/lib/components/chat/Messages/UserMessage.svelte +++ b/src/lib/components/chat/Messages/UserMessage.svelte @@ -117,6 +117,35 @@
Document
+ {:else if file.type === 'collection'} + {/if}
{/each} diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 956b6cb0..376c4e37 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -28,7 +28,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryCollection } from '$lib/apis/rag'; + import { queryCollection, queryDoc } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -224,7 +224,9 @@ const docs = messages .filter((message) => message?.files ?? null) - .map((message) => message.files.filter((item) => item.type === 'doc')) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) .flat(1); console.log(docs); @@ -234,12 +236,21 @@ let relevantContexts = await Promise.all( docs.map(async (doc) => { - return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); + if (doc.type === 'collection') { + return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } else { + return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } }) ); relevantContexts = relevantContexts.filter((context) => context); diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index fac8a01c..83e72c62 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -29,7 +29,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryCollection } from '$lib/apis/rag'; + import { queryCollection, queryDoc } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -238,7 +238,9 @@ const docs = messages .filter((message) => message?.files ?? null) - .map((message) => message.files.filter((item) => item.type === 'doc')) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) .flat(1); console.log(docs); @@ -248,12 +250,21 @@ let relevantContexts = await Promise.all( docs.map(async (doc) => { - return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); + if (doc.type === 'collection') { + return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } else { + return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } }) ); relevantContexts = relevantContexts.filter((context) => context); From 38584856754bac0e7b87ca2e02f99ca441a663c3 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 3 Feb 2024 17:06:59 -0800 Subject: [PATCH 10/12] fix: styling --- src/lib/components/common/Tags/TagList.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/components/common/Tags/TagList.svelte b/src/lib/components/common/Tags/TagList.svelte index ea918ca5..66a0b060 100644 --- a/src/lib/components/common/Tags/TagList.svelte +++ b/src/lib/components/common/Tags/TagList.svelte @@ -9,7 +9,7 @@
-
+
{tag.name}