diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index 85bc995a..95535274 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -10,6 +10,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil
+from typing import List
# from chromadb.utils import embedding_functions
@@ -96,19 +97,22 @@ async def get_status():
return {"status": True}
-@app.get("/query/{collection_name}")
-def query_collection(
- collection_name: str,
- query: str,
- k: Optional[int] = 4,
+# Request body for POST /query/doc: a query against a single collection.
+class QueryDocForm(BaseModel):
+ # Name handed to CHROMA_CLIENT.get_collection().
+ collection_name: str
+ # Text sent as the only entry of `query_texts`.
+ query: str
+ # Forwarded as `n_results`; defaults to 4 when the client omits it.
+ k: Optional[int] = 4
+
+
+@app.post("/query/doc")
+def query_doc(
+ form_data: QueryDocForm,
user=Depends(get_current_user),
):
try:
collection = CHROMA_CLIENT.get_collection(
- name=collection_name,
+ name=form_data.collection_name,
)
- result = collection.query(query_texts=[query], n_results=k)
-
+ result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
except Exception as e:
print(e)
@@ -118,6 +122,79 @@ def query_collection(
)
+# Request body for POST /query/collection: one query run against several collections.
+class QueryCollectionsForm(BaseModel):
+ # Each name is looked up with CHROMA_CLIENT.get_collection(); names whose
+ # lookup or query raises are skipped by the handler.
+ collection_names: List[str]
+ # Text sent as the only entry of `query_texts` for every collection.
+ query: str
+ # Used as `n_results` for each collection and as the size cap of the merged
+ # result; defaults to 4 when the client omits it.
+ k: Optional[int] = 4
+
+
+def merge_and_sort_query_results(query_results, k):
+    """Merge several query results into one, ordered by ascending distance.
+
+    Args:
+        query_results: Iterable of result dicts. Only index 0 of each of
+            "ids", "distances", "metadatas" and "documents" is read, i.e. the
+            hits for a single query text.
+        k: Maximum number of hits to keep; ``None`` keeps all of them.
+
+    Returns:
+        A dict whose "ids", "distances", "metadatas" and "documents" entries
+        each hold one inner list with the ``k`` smallest-distance hits in
+        ascending order. "embeddings", "uris" and "data" are always ``None``.
+        When there are no hits at all, the inner lists are empty.
+    """
+    # Keep each hit's fields together so they stay aligned through the sort.
+    hits = []
+    for data in query_results:
+        hits.extend(
+            zip(
+                data["distances"][0],
+                data["ids"][0],
+                data["metadatas"][0],
+                data["documents"][0],
+            )
+        )
+
+    # Sort on the distance only: metadata dicts are not orderable, so they
+    # must never be reached as a tie-breaker.
+    hits.sort(key=lambda hit: hit[0])
+    top_hits = hits[:k]
+
+    # Comprehensions instead of zip(*...) unpacking, which raises ValueError
+    # when there is nothing to unpack (e.g. every collection lookup failed).
+    return {
+        "ids": [[hit[1] for hit in top_hits]],
+        "distances": [[hit[0] for hit in top_hits]],
+        "metadatas": [[hit[2] for hit in top_hits]],
+        "documents": [[hit[3] for hit in top_hits]],
+        "embeddings": None,
+        "uris": None,
+        "data": None,
+    }
+
+
+@app.post("/query/collection")
+def query_collection(
+    form_data: QueryCollectionsForm,
+    user=Depends(get_current_user),
+):
+    """Run one query against every named collection and merge the hits.
+
+    Each collection is queried for ``form_data.k`` results. A collection whose
+    lookup or query raises is skipped, so one bad name does not fail the
+    whole request. The per-collection results are passed to
+    ``merge_and_sort_query_results`` together with ``form_data.k``.
+    """
+    results = []
+
+    for collection_name in form_data.collection_names:
+        try:
+            collection = CHROMA_CLIENT.get_collection(
+                name=collection_name,
+            )
+            result = collection.query(
+                query_texts=[form_data.query], n_results=form_data.k
+            )
+            results.append(result)
+        except Exception as e:
+            # Still best effort, but say which collection was dropped and why
+            # instead of swallowing the failure silently. `Exception` rather
+            # than a bare except so KeyboardInterrupt/SystemExit propagate.
+            print(f"query_collection: skipping '{collection_name}': {e}")
+
+    return merge_and_sort_query_results(results, form_data.k)
+
+
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
diff --git a/backend/apps/web/models/documents.py b/backend/apps/web/models/documents.py
index 0196c38b..6a372b2c 100644
--- a/backend/apps/web/models/documents.py
+++ b/backend/apps/web/models/documents.py
@@ -44,6 +44,16 @@ class DocumentModel(BaseModel):
####################
+# Response shape used by the document routes in place of DocumentModel.
+class DocumentResponse(BaseModel):
+ collection_name: str
+ name: str
+ title: str
+ filename: str
+ # Parsed JSON object; the routes fill it with json.loads() of the stored
+ # content string, or of "{}" when that string is empty.
+ content: Optional[dict] = None
+ user_id: str
+ timestamp: int # timestamp in epoch
+
+
class DocumentUpdateForm(BaseModel):
name: str
title: str
@@ -111,6 +121,26 @@ class DocumentsTable:
print(e)
return None
+ def update_doc_content_by_name(
+     self, name: str, updated: dict
+ ) -> Optional[DocumentModel]:
+     """Shallow-merge ``updated`` into a document's JSON content and save it.
+
+     Args:
+         name: Value matched against ``Document.name``.
+         updated: Top-level keys to set; on a clash they replace the
+             existing keys of the stored content.
+
+     Returns:
+         The document re-read after the update (its timestamp is refreshed
+         too), or ``None`` when the document is missing or any step raises.
+     """
+     try:
+         doc = self.get_doc_by_name(name)
+         if doc is None:
+             # Explicit not-found path instead of relying on `doc.content`
+             # raising AttributeError into the handler below.
+             return None
+
+         # Content is stored as a JSON string; empty content means "{}".
+         doc_content = json.loads(doc.content if doc.content else "{}")
+         doc_content = {**doc_content, **updated}
+
+         query = Document.update(
+             content=json.dumps(doc_content),
+             timestamp=int(time.time()),
+         ).where(Document.name == name)
+         query.execute()
+
+         doc = Document.get(Document.name == name)
+         return DocumentModel(**model_to_dict(doc))
+     except Exception as e:
+         print(e)
+         return None
+
def delete_doc_by_name(self, name: str) -> bool:
try:
query = Document.delete().where((Document.name == name))
diff --git a/backend/apps/web/routers/documents.py b/backend/apps/web/routers/documents.py
index c64ee8f0..3b6434d1 100644
--- a/backend/apps/web/routers/documents.py
+++ b/backend/apps/web/routers/documents.py
@@ -11,6 +11,7 @@ from apps.web.models.documents import (
DocumentForm,
DocumentUpdateForm,
DocumentModel,
+ DocumentResponse,
)
from utils.utils import get_current_user
@@ -23,9 +24,18 @@ router = APIRouter()
############################
-@router.get("/", response_model=List[DocumentModel])
+@router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user)):
- return Documents.get_docs()
+ docs = [
+ DocumentResponse(
+ **{
+ **doc.model_dump(),
+ "content": json.loads(doc.content if doc.content else "{}"),
+ }
+ )
+ for doc in Documents.get_docs()
+ ]
+ return docs
############################
@@ -33,7 +43,7 @@ async def get_documents(user=Depends(get_current_user)):
############################
-@router.post("/create", response_model=Optional[DocumentModel])
+@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
@@ -46,7 +56,12 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)
doc = Documents.insert_new_doc(user.id, form_data)
if doc:
- return doc
+ return DocumentResponse(
+ **{
+ **doc.model_dump(),
+ "content": json.loads(doc.content if doc.content else "{}"),
+ }
+ )
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@@ -64,12 +79,45 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)
############################
-@router.get("/name/{name}", response_model=Optional[DocumentModel])
+@router.get("/name/{name}", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
if doc:
- return doc
+ return DocumentResponse(
+ **{
+ **doc.model_dump(),
+ "content": json.loads(doc.content if doc.content else "{}"),
+ }
+ )
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail=ERROR_MESSAGES.NOT_FOUND,
+ )
+
+
+############################
+# TagDocByName
+############################
+
+
+# Request body for POST /name/{name}/tags.
+class TagDocumentForm(BaseModel):
+ # Document to tag; the handler looks the document up by this field.
+ name: str
+ # Saved under the "tags" key of the document's content.
+ tags: List[dict]
+
+
+@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
+async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
+ doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
+
+ if doc:
+ return DocumentResponse(
+ **{
+ **doc.model_dump(),
+ "content": json.loads(doc.content if doc.content else "{}"),
+ }
+ )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -82,7 +130,7 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
############################
-@router.post("/name/{name}/update", response_model=Optional[DocumentModel])
+@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
):
@@ -94,7 +142,12 @@ async def update_doc_by_name(
doc = Documents.update_doc_by_name(name, form_data)
if doc:
- return doc
+ return DocumentResponse(
+ **{
+ **doc.model_dump(),
+ "content": json.loads(doc.content if doc.content else "{}"),
+ }
+ )
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
diff --git a/src/lib/apis/documents/index.ts b/src/lib/apis/documents/index.ts
index fb208ea4..2f7fb2b9 100644
--- a/src/lib/apis/documents/index.ts
+++ b/src/lib/apis/documents/index.ts
@@ -144,6 +144,47 @@ export const updateDocByName = async (token: string, name: string, form: DocUpda
return res;
};
+type TagDocForm = {
+	name: string;
+	// Tag objects, not bare strings: the backend validates `tags` as a list
+	// of objects and the UI reads `tag.name` from each entry.
+	tags: { name: string }[];
+};
+
+/**
+ * Replace the tag list of a document.
+ *
+ * @param token Bearer token sent in the `authorization` header.
+ * @param name  Document name used in the request path.
+ * @param form  Document name and the complete tag list to store.
+ * @returns The parsed JSON response; `null` when the request fails without
+ *          a `detail` field.
+ * @throws The `detail` field of the error response, when one is present.
+ */
+export const tagDocByName = async (token: string, name: string, form: TagDocForm) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/documents/name/${name}/tags`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			name: form.name,
+			tags: form.tags
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
export const deleteDocByName = async (token: string, name: string) => {
let error = null;
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index 8a2b8cb4..3f4f30bf 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -64,30 +64,64 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
return res;
};
-export const queryVectorDB = async (
+export const queryDoc = async (
token: string,
collection_name: string,
query: string,
k: number
) => {
let error = null;
- const searchParams = new URLSearchParams();
- searchParams.set('query', query);
- if (k) {
- searchParams.set('k', k.toString());
+ const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, {
+ method: 'POST',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ authorization: `Bearer ${token}`
+ },
+ body: JSON.stringify({
+ collection_name: collection_name,
+ query: query,
+ k: k
+ })
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ error = err.detail;
+ return null;
+ });
+
+ if (error) {
+ throw error;
}
- const res = await fetch(
- `${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`,
- {
- method: 'GET',
- headers: {
- Accept: 'application/json',
- authorization: `Bearer ${token}`
- }
- }
- )
+ return res;
+};
+
+export const queryCollection = async (
+ token: string,
+ collection_names: string,
+ query: string,
+ k: number
+) => {
+ let error = null;
+
+ const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
+ method: 'POST',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ authorization: `Bearer ${token}`
+ },
+ body: JSON.stringify({
+ collection_names: collection_names,
+ query: query,
+ k: k
+ })
+ })
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
diff --git a/src/lib/components/AddFilesPlaceholder.svelte b/src/lib/components/AddFilesPlaceholder.svelte
index 7cc51c6f..3cbd0454 100644
--- a/src/lib/components/AddFilesPlaceholder.svelte
+++ b/src/lib/components/AddFilesPlaceholder.svelte
@@ -1,6 +1,8 @@
📄
Add Files
-
- Drop any files here to add to the conversation
-
+
+ Drop any files here to add to the conversation
+
+
diff --git a/src/lib/components/chat/MessageInput.svelte b/src/lib/components/chat/MessageInput.svelte
index 2aec58d4..20eab5b6 100644
--- a/src/lib/components/chat/MessageInput.svelte
+++ b/src/lib/components/chat/MessageInput.svelte
@@ -295,7 +295,7 @@
files = [
...files,
{
- type: 'doc',
+ type: e?.detail?.type ?? 'doc',
...e.detail,
upload_status: true
}
@@ -446,6 +446,34 @@
Document
+ {:else if file.type === 'collection'}
+
+
+
+
+
+ {file?.title ?? `#${file.name}`}
+
+
+
Collection
+
+
{/if}
diff --git a/src/lib/components/chat/MessageInput/Documents.svelte b/src/lib/components/chat/MessageInput/Documents.svelte
index 5f252b3d..587e59c0 100644
--- a/src/lib/components/chat/MessageInput/Documents.svelte
+++ b/src/lib/components/chat/MessageInput/Documents.svelte
@@ -10,14 +10,50 @@
const dispatch = createEventDispatcher();
let selectedIdx = 0;
+
+ let filteredItems = [];
let filteredDocs = [];
+ let collections = [];
+
+ // Selectable "collection" entries derived from the documents store: one
+ // catch-all entry plus one entry per distinct tag name.
+ $: collections = [
+ // The catch-all entry only exists when there is at least one document.
+ ...($documents.length > 0
+ ? [
+ {
+ name: 'All Documents',
+ type: 'collection',
+ title: 'All Documents',
+ collection_names: $documents.map((doc) => doc.collection_name)
+ }
+ ]
+ : []),
+ ...$documents
+ // Collect the distinct tag names across all documents (Set drops repeats).
+ .reduce((a, e, i, arr) => {
+ return [...new Set([...a, ...(e?.content?.tags ?? []).map((tag) => tag.name)])];
+ }, [])
+ // One entry per tag, listing the collection of every document carrying it.
+ .map((tag) => ({
+ name: tag,
+ type: 'collection',
+ collection_names: $documents
+ .filter((doc) => (doc?.content?.tags ?? []).map((tag) => tag.name).includes(tag))
+ .map((doc) => doc.collection_name)
+ }))
+ ];
+
+ // Entries whose name contains the first word of the prompt minus its first
+ // character, sorted by name.
+ $: filteredCollections = collections
+ .filter((collection) => collection.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? ''))
+ .sort((a, b) => a.name.localeCompare(b.name));
+
$: filteredDocs = $documents
- .filter((p) => p.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? ''))
+ .filter((doc) => doc.name.includes(prompt.split(' ')?.at(0)?.substring(1) ?? ''))
.sort((a, b) => a.title.localeCompare(b.title));
+ $: filteredItems = [...filteredCollections, ...filteredDocs];
+
+ // Move the highlight back to the first entry whenever the prompt changes.
$: if (prompt) {
selectedIdx = 0;
}
export const selectUp = () => {
@@ -25,7 +61,7 @@
};
export const selectDown = () => {
- selectedIdx = Math.min(selectedIdx + 1, filteredDocs.length - 1);
+ selectedIdx = Math.min(selectedIdx + 1, filteredItems.length - 1);
};
const confirmSelect = async (doc) => {
@@ -51,7 +87,7 @@
};
-{#if filteredDocs.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
+{#if filteredItems.length > 0 || prompt.split(' ')?.at(0)?.substring(1).startsWith('http')}
@@ -60,7 +96,7 @@
- {#each filteredDocs as doc, docIdx}
+ {#each filteredItems as doc, docIdx}
{/each}
diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte
index eaf4dac8..a1993648 100644
--- a/src/lib/components/chat/Messages/UserMessage.svelte
+++ b/src/lib/components/chat/Messages/UserMessage.svelte
@@ -117,6 +117,35 @@
Document
+ {:else if file.type === 'collection'}
+
{/if}
{/each}
diff --git a/src/lib/components/common/Tags.svelte b/src/lib/components/common/Tags.svelte
new file mode 100644
index 00000000..426678fb
--- /dev/null
+++ b/src/lib/components/common/Tags.svelte
@@ -0,0 +1,24 @@
+
+
+
+ {
+ deleteTag(e.detail);
+ }}
+ />
+
+ {
+ addTag(e.detail);
+ }}
+ />
+
diff --git a/src/lib/components/common/Tags/TagInput.svelte b/src/lib/components/common/Tags/TagInput.svelte
new file mode 100644
index 00000000..64726cf8
--- /dev/null
+++ b/src/lib/components/common/Tags/TagInput.svelte
@@ -0,0 +1,64 @@
+
+
+
+ {#if showTagInput}
+
+
+
+
+
+
+
+ {/if}
+
+
+
diff --git a/src/lib/components/common/Tags/TagList.svelte b/src/lib/components/common/Tags/TagList.svelte
new file mode 100644
index 00000000..66a0b060
--- /dev/null
+++ b/src/lib/components/common/Tags/TagList.svelte
@@ -0,0 +1,33 @@
+
+
+{#each tags as tag}
+
+
+ {tag.name}
+
+
+
+{/each}
diff --git a/src/lib/components/documents/EditDocModal.svelte b/src/lib/components/documents/EditDocModal.svelte
index 35ea1539..4974da5a 100644
--- a/src/lib/components/documents/EditDocModal.svelte
+++ b/src/lib/components/documents/EditDocModal.svelte
@@ -3,16 +3,22 @@
import dayjs from 'dayjs';
import { onMount } from 'svelte';
- import { getDocs, updateDocByName } from '$lib/apis/documents';
+ import { getDocs, tagDocByName, updateDocByName } from '$lib/apis/documents';
import Modal from '../common/Modal.svelte';
import { documents } from '$lib/stores';
+ import TagInput from '../common/Tags/TagInput.svelte';
+ import Tags from '../common/Tags.svelte';
+ import { addTagById } from '$lib/apis/chats';
export let show = false;
export let selectedDoc;
+ let tags = [];
+
let doc = {
name: '',
- title: ''
+ title: '',
+ content: null
};
const submitHandler = async () => {
@@ -30,9 +36,37 @@
}
};
+	// Append a tag to the document, persist the full tag list, then reload the
+	// documents store. Empty names and names already present are only logged.
+	const addTagHandler = async (tagName) => {
+		const alreadyTagged = tags.some((tag) => tag.name === tagName);
+
+		if (alreadyTagged || tagName === '') {
+			console.log('tag already exists');
+			return;
+		}
+
+		tags = [...tags, { name: tagName }];
+
+		const payload = { name: doc.name, tags: tags };
+		await tagDocByName(localStorage.token, doc.name, payload);
+
+		const refreshedDocs = await getDocs(localStorage.token);
+		documents.set(refreshedDocs);
+	};
+
+	// Drop every tag with the given name, persist the remaining list, then
+	// reload the documents store.
+	const deleteTagHandler = async (tagName) => {
+		const remainingTags = tags.filter((tag) => tag.name !== tagName);
+		tags = remainingTags;
+
+		const payload = { name: doc.name, tags: remainingTags };
+		await tagDocByName(localStorage.token, doc.name, payload);
+
+		const refreshedDocs = await getDocs(localStorage.token);
+		documents.set(refreshedDocs);
+	};
+
onMount(() => {
if (selectedDoc) {
doc = JSON.parse(JSON.stringify(selectedDoc));
+
+ tags = doc?.content?.tags ?? [];
}
});
@@ -112,6 +146,12 @@
/>
+
+
diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte
index 3fc0ffa0..e829aeb3 100644
--- a/src/lib/components/layout/Navbar.svelte
+++ b/src/lib/components/layout/Navbar.svelte
@@ -6,6 +6,8 @@
import { getChatById } from '$lib/apis/chats';
import { chatId, modelfiles } from '$lib/stores';
import ShareChatModal from '../chat/ShareChatModal.svelte';
+ import TagInput from '../common/Tags/TagInput.svelte';
+ import Tags from '../common/Tags.svelte';
export let initNewChat: Function;
export let title: string = 'Ollama Web UI';
@@ -61,21 +63,6 @@
saveAs(blob, `chat-${chat.title}.txt`);
};
-
- const addTagHandler = () => {
- // if (!tags.find((e) => e.name === tagName)) {
- // tags = [
- // ...tags,
- // {
- // name: JSON.parse(JSON.stringify(tagName))
- // }
- // ];
- // }
-
- addTag(tagName);
- tagName = '';
- showTagInput = false;
- };
@@ -116,87 +103,7 @@
{#if shareEnabled}
-
- {#each tags as tag}
-
-
- {tag.name}
-
-
-
- {/each}
-
-
- {#if showTagInput}
-
-
-
-
-
-
-
- {/if}
-
-
-
-
+