diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 85bc995a..95535274 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -10,6 +10,7 @@ from fastapi import ( ) from fastapi.middleware.cors import CORSMiddleware import os, shutil +from typing import List # from chromadb.utils import embedding_functions @@ -96,19 +97,22 @@ async def get_status(): return {"status": True} -@app.get("/query/{collection_name}") -def query_collection( - collection_name: str, - query: str, - k: Optional[int] = 4, +class QueryDocForm(BaseModel): + collection_name: str + query: str + k: Optional[int] = 4 + + +@app.post("/query/doc") +def query_doc( + form_data: QueryDocForm, user=Depends(get_current_user), ): try: collection = CHROMA_CLIENT.get_collection( - name=collection_name, + name=form_data.collection_name, ) - result = collection.query(query_texts=[query], n_results=k) - + result = collection.query(query_texts=[form_data.query], n_results=form_data.k) return result except Exception as e: print(e) @@ -118,6 +122,79 @@ def query_collection( ) +class QueryCollectionsForm(BaseModel): + collection_names: List[str] + query: str + k: Optional[int] = 4 + + +def merge_and_sort_query_results(query_results, k): + # Initialize lists to store combined data + combined_ids = [] + combined_distances = [] + combined_metadatas = [] + combined_documents = [] + + # Combine data from each dictionary + for data in query_results: + combined_ids.extend(data["ids"][0]) + combined_distances.extend(data["distances"][0]) + combined_metadatas.extend(data["metadatas"][0]) + combined_documents.extend(data["documents"][0]) + + # Create a list of tuples (distance, id, metadata, document) + combined = list( + zip(combined_distances, combined_ids, combined_metadatas, combined_documents) + ) + + # Sort the list based on distances + combined.sort(key=lambda x: x[0]) + + # Unzip the sorted list + sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined) + + # Slicing the lists to include only k elements + sorted_distances = list(sorted_distances)[:k] + sorted_ids = list(sorted_ids)[:k] + sorted_metadatas = list(sorted_metadatas)[:k] + sorted_documents = list(sorted_documents)[:k] + + # Create the output dictionary + merged_query_results = { + "ids": [sorted_ids], + "distances": [sorted_distances], + "metadatas": [sorted_metadatas], + "documents": [sorted_documents], + "embeddings": None, + "uris": None, + "data": None, + } + + return merged_query_results + + +@app.post("/query/collection") +def query_collection( + form_data: QueryCollectionsForm, + user=Depends(get_current_user), +): + results = [] + + for collection_name in form_data.collection_names: + try: + collection = CHROMA_CLIENT.get_collection( + name=collection_name, + ) + result = collection.query( + query_texts=[form_data.query], n_results=form_data.k + ) + results.append(result) + except: + pass + + return merge_and_sort_query_results(results, form_data.k) + + @app.post("/web") def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" diff --git a/backend/apps/web/models/documents.py b/backend/apps/web/models/documents.py index 0196c38b..6a372b2c 100644 --- a/backend/apps/web/models/documents.py +++ b/backend/apps/web/models/documents.py @@ -44,6 +44,16 @@ class DocumentModel(BaseModel): #################### +class DocumentResponse(BaseModel): + collection_name: str + name: str + title: str + filename: str + content: Optional[dict] = None + user_id: str + timestamp: int # timestamp in epoch + + class DocumentUpdateForm(BaseModel): name: str title: str @@ -111,6 +121,26 @@ class DocumentsTable: print(e) return None + def update_doc_content_by_name( + self, name: str, updated: dict + ) -> Optional[DocumentModel]: + try: + doc = self.get_doc_by_name(name) + doc_content = json.loads(doc.content if doc.content else "{}") + doc_content = {**doc_content, **updated} + + query = Document.update( + content=json.dumps(doc_content), + timestamp=int(time.time()), + ).where(Document.name == name) + query.execute() + + doc = Document.get(Document.name == name) + return DocumentModel(**model_to_dict(doc)) + except Exception as e: + print(e) + return None + def delete_doc_by_name(self, name: str) -> bool: try: query = Document.delete().where((Document.name == name)) diff --git a/backend/apps/web/routers/documents.py b/backend/apps/web/routers/documents.py index c64ee8f0..3b6434d1 100644 --- a/backend/apps/web/routers/documents.py +++ b/backend/apps/web/routers/documents.py @@ -11,6 +11,7 @@ from apps.web.models.documents import ( DocumentForm, DocumentUpdateForm, DocumentModel, + DocumentResponse, ) from utils.utils import get_current_user @@ -23,9 +24,18 @@ router = APIRouter() ############################ -@router.get("/", response_model=List[DocumentModel]) +@router.get("/", response_model=List[DocumentResponse]) async def get_documents(user=Depends(get_current_user)): - return Documents.get_docs() + docs = [ + DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) + for doc in Documents.get_docs() + ] + return docs ############################ @@ -33,7 +43,7 @@ async def get_documents(user=Depends(get_current_user)): ############################ -@router.post("/create", response_model=Optional[DocumentModel]) +@router.post("/create", response_model=Optional[DocumentResponse]) async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)): if user.role != "admin": raise HTTPException( @@ -46,7 +56,12 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) doc = Documents.insert_new_doc(user.id, form_data) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -64,12 +79,45 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user) ############################ -@router.get("/name/{name}", response_model=Optional[DocumentModel]) +@router.get("/name/{name}", response_model=Optional[DocumentResponse]) async def get_doc_by_name(name: str, user=Depends(get_current_user)): doc = Documents.get_doc_by_name(name) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + + +############################ +# TagDocByName +############################ + + +class TagDocumentForm(BaseModel): + name: str + tags: List[dict] + + +@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse]) +async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)): + doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags}) + + if doc: + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -82,7 +130,7 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)): ############################ -@router.post("/name/{name}/update", response_model=Optional[DocumentModel]) +@router.post("/name/{name}/update", response_model=Optional[DocumentResponse]) async def update_doc_by_name( name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user) ): @@ -94,7 +142,12 @@ async def update_doc_by_name( doc = Documents.update_doc_by_name(name, form_data) if doc: - return doc + return DocumentResponse( + **{ + **doc.model_dump(), + "content": json.loads(doc.content if doc.content else "{}"), + } + ) else: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/src/lib/apis/documents/index.ts b/src/lib/apis/documents/index.ts index fb208ea4..2f7fb2b9 100644 --- a/src/lib/apis/documents/index.ts +++ b/src/lib/apis/documents/index.ts @@ -144,6 +144,47 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda return res; }; +type TagDocForm = { + name: string; + tags: string[]; +}; + +export const tagDocByName = async (token: string, name: string, form: TagDocForm) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/documents/name/${name}/tags`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + name: form.name, + tags: form.tags + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err.detail; + + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const deleteDocByName = async (token: string, name: string) => { let error = null; diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index 8a2b8cb4..3f4f30bf 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -64,30 +64,64 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string return res; }; -export const queryVectorDB = async ( +export const queryDoc = async ( token: string, collection_name: string, query: string, k: number ) => { let error = null; - const searchParams = new URLSearchParams(); - searchParams.set('query', query); - if (k) { - searchParams.set('k', k.toString()); + const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_name: collection_name, + query: query, + k: k + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; } - const res = await fetch( - `${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`, - { - method: 'GET', - headers: { - Accept: 'application/json', - authorization: `Bearer ${token}` - } - } - ) + return res; +}; + +export const queryCollection = async ( + token: string, + collection_names: string, + query: string, + k: number +) => { + let error = null; + + const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_names: collection_names, + query: query, + k: k + }) + }) .then(async (res) => { if (!res.ok) throw await res.json(); return res.json(); diff --git a/src/lib/components/AddFilesPlaceholder.svelte b/src/lib/components/AddFilesPlaceholder.svelte index 7cc51c6f..3cbd0454 100644 --- a/src/lib/components/AddFilesPlaceholder.svelte +++ b/src/lib/components/AddFilesPlaceholder.svelte @@ -1,6 +1,8 @@
📄
Add Files
-
- Drop any files here to add to the conversation -
+
+ Drop any files here to add to the conversation +
+
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte index 2aec58d4..20eab5b6 100644 --- a/src/lib/components/chat/MessageInput.svelte +++ b/src/lib/components/chat/MessageInput.svelte @@ -295,7 +295,7 @@ files = [ ...files, { - type: 'doc', + type: e?.detail?.type ?? 'doc', ...e.detail, upload_status: true } @@ -446,6 +446,34 @@
Document
+ {:else if file.type === 'collection'} +
+
+ + + + +
+ +
+
+ {file?.title ?? `#${file.name}`} +
+ +
Collection
+
+
{/if}
diff --git a/src/lib/components/chat/MessageInput/Documents.svelte b/src/lib/components/chat/MessageInput/Documents.svelte index 5f252b3d..587e59c0 100644 --- a/src/lib/components/chat/MessageInput/Documents.svelte +++ b/src/lib/components/chat/MessageInput/Documents.svelte @@ -10,14 +10,50 @@ const dispatch = createEventDispatcher(); let selectedIdx = 0; + + let filteredItems = []; let filteredDocs = []; + let collections = []; + + $: collections = [ + ...($documents.length > 0 + ? [ + { + name: 'All Documents', + type: 'collection', + title: 'All Documents', + collection_names: $documents.map((doc) => doc.collection_name) + } + ] + : []), + ...$documents + .reduce((a, e, i, arr) => { + return [...new Set([...a, ...(e?.content?.tags ?? []).map((tag) => tag.name)])]; + }, []) + .map((tag) => ({ + name: tag, + type: 'collection', + collection_names: $documents + .filter((doc) => (doc?.content?.tags ?? []).map((tag) => tag.name).includes(tag)) + .map((doc) => doc.collection_name) + })) + ]; + + $: filteredCollections = collections + .filter((collection) => collection.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? '')) + .sort((a, b) => a.name.localeCompare(b.name)); + $: filteredDocs = $documents - .filter((p) => p.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? '')) + .filter((doc) => doc.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? '')) .sort((a, b) => a.title.localeCompare(b.title)); + $: filteredItems = [...filteredCollections, ...filteredDocs]; + $: if (prompt) { selectedIdx = 0; + + console.log(filteredCollections); } export const selectUp = () => { @@ -25,7 +61,7 @@ }; export const selectDown = () => { - selectedIdx = Math.min(selectedIdx + 1, filteredDocs.length - 1); + selectedIdx = Math.min(selectedIdx + 1, filteredItems.length - 1); }; const confirmSelect = async (doc) => { @@ -51,7 +87,7 @@ }; -{#if filteredDocs.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')} +{#if filteredItems.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
@@ -60,7 +96,7 @@
- {#each filteredDocs as doc, docIdx} + {#each filteredItems as doc, docIdx} {/each} diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte index eaf4dac8..a1993648 100644 --- a/src/lib/components/chat/Messages/UserMessage.svelte +++ b/src/lib/components/chat/Messages/UserMessage.svelte @@ -117,6 +117,35 @@
Document
+ {:else if file.type === 'collection'} + {/if}
{/each} diff --git a/src/lib/components/common/Tags.svelte b/src/lib/components/common/Tags.svelte new file mode 100644 index 00000000..426678fb --- /dev/null +++ b/src/lib/components/common/Tags.svelte @@ -0,0 +1,24 @@ + + +
+ { + deleteTag(e.detail); + }} + /> + + { + addTag(e.detail); + }} + /> +
diff --git a/src/lib/components/common/Tags/TagInput.svelte b/src/lib/components/common/Tags/TagInput.svelte new file mode 100644 index 00000000..64726cf8 --- /dev/null +++ b/src/lib/components/common/Tags/TagInput.svelte @@ -0,0 +1,64 @@ + + +
+ {#if showTagInput} +
+ + + +
+ + + {/if} + + +
diff --git a/src/lib/components/common/Tags/TagList.svelte b/src/lib/components/common/Tags/TagList.svelte new file mode 100644 index 00000000..66a0b060 --- /dev/null +++ b/src/lib/components/common/Tags/TagList.svelte @@ -0,0 +1,33 @@ + + +{#each tags as tag} +
+
+ {tag.name} +
+ +
+{/each} diff --git a/src/lib/components/documents/EditDocModal.svelte b/src/lib/components/documents/EditDocModal.svelte index 35ea1539..4974da5a 100644 --- a/src/lib/components/documents/EditDocModal.svelte +++ b/src/lib/components/documents/EditDocModal.svelte @@ -3,16 +3,22 @@ import dayjs from 'dayjs'; import { onMount } from 'svelte'; - import { getDocs, updateDocByName } from '$lib/apis/documents'; + import { getDocs, tagDocByName, updateDocByName } from '$lib/apis/documents'; import Modal from '../common/Modal.svelte'; import { documents } from '$lib/stores'; + import TagInput from '../common/Tags/TagInput.svelte'; + import Tags from '../common/Tags.svelte'; + import { addTagById } from '$lib/apis/chats'; export let show = false; export let selectedDoc; + let tags = []; + let doc = { name: '', - title: '' + title: '', + content: null }; const submitHandler = async () => { @@ -30,9 +36,37 @@ } }; + const addTagHandler = async (tagName) => { + if (!tags.find((tag) => tag.name === tagName) && tagName !== '') { + tags = [...tags, { name: tagName }]; + + await tagDocByName(localStorage.token, doc.name, { + name: doc.name, + tags: tags + }); + + documents.set(await getDocs(localStorage.token)); + } else { + console.log('tag already exists'); + } + }; + + const deleteTagHandler = async (tagName) => { + tags = tags.filter((tag) => tag.name !== tagName); + + await tagDocByName(localStorage.token, doc.name, { + name: doc.name, + tags: tags + }); + + documents.set(await getDocs(localStorage.token)); + }; + onMount(() => { if (selectedDoc) { doc = JSON.parse(JSON.stringify(selectedDoc)); + + tags = doc?.content?.tags ?? []; } }); @@ -112,6 +146,12 @@ />
+ +
+
Tags
+ + +
diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index 3fc0ffa0..e829aeb3 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -6,6 +6,8 @@ import { getChatById } from '$lib/apis/chats'; import { chatId, modelfiles } from '$lib/stores'; import ShareChatModal from '../chat/ShareChatModal.svelte'; + import TagInput from '../common/Tags/TagInput.svelte'; + import Tags from '../common/Tags.svelte'; export let initNewChat: Function; export let title: string = 'Ollama Web UI'; @@ -61,21 +63,6 @@ saveAs(blob, `chat-${chat.title}.txt`); }; - - const addTagHandler = () => { - // if (!tags.find((e) => e.name === tagName)) { - // tags = [ - // ...tags, - // { - // name: JSON.parse(JSON.stringify(tagName)) - // } - // ]; - // } - - addTag(tagName); - tagName = ''; - showTagInput = false; - }; @@ -116,87 +103,7 @@
{#if shareEnabled} -
- {#each tags as tag} -
-
- {tag.name} -
- -
- {/each} - -
- {#if showTagInput} -
- - - -
- - - {/if} - - -
-
+
+
- { - if (inputFiles && inputFiles.length > 0) { - const file = inputFiles[0]; - if ( - SUPPORTED_FILE_TYPE.includes(file['type']) || - SUPPORTED_FILE_EXTENSIONS.includes(file.name.split('.').at(-1)) - ) { - uploadDoc(file); - } else { - toast.error( - `Unknown File Type '${file['type']}', but accepting and treating as plain text` - ); - uploadDoc(file); - } + {#if tags.length > 0} +
+ + {#each tags as tag} + + {/each} +
+ {/if} - inputFiles = null; - e.target.value = ''; - } else { - toast.error(`File not found.`); - } - }} - /> - -
+ - {#each $documents.filter((p) => query === '' || p.name.includes(query)) as doc} -
-
+ {#each $documents.filter((doc) => (selectedTag === '' || (doc?.content?.tags ?? []) + .map((tag) => tag.name) + .includes(selectedTag)) && (query === '' || doc.name.includes(query))) as doc} +
@@ -330,106 +417,97 @@
{/each} - {#if $documents.length != 0} -
-
-
- { - console.log(importFiles); +
- const reader = new FileReader(); - reader.onload = async (event) => { - const savedDocs = JSON.parse(event.target.result); - console.log(savedDocs); +
+
+ { + console.log(importFiles); - for (const doc of savedDocs) { - await createNewDoc( - localStorage.token, - doc.collection_name, - doc.filename, - doc.name, - doc.title - ).catch((error) => { - toast.error(error); - return null; - }); - } + const reader = new FileReader(); + reader.onload = async (event) => { + const savedDocs = JSON.parse(event.target.result); + console.log(savedDocs); - await documents.set(await getDocs(localStorage.token)); - }; + for (const doc of savedDocs) { + await createNewDoc( + localStorage.token, + doc.collection_name, + doc.filename, + doc.name, + doc.title + ).catch((error) => { + toast.error(error); + return null; + }); + } - reader.readAsText(importFiles[0]); - }} - /> + await documents.set(await getDocs(localStorage.token)); + }; - - - - - -
+
Import Documents Mapping
+ +
+ + + +
+ + +
- {/if} +