diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index de00a581..95535274 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -97,15 +97,15 @@ async def get_status(): return {"status": True} -class QueryCollectionForm(BaseModel): +class QueryDocForm(BaseModel): collection_name: str query: str k: Optional[int] = 4 -@app.post("/query/collection") -def query_collection( - form_data: QueryCollectionForm, +@app.post("/query/doc") +def query_doc( + form_data: QueryDocForm, user=Depends(get_current_user), ): try: @@ -173,8 +173,8 @@ def merge_and_sort_query_results(query_results, k): return merged_query_results -@app.post("/query/collections") -def query_collections( +@app.post("/query/collection") +def query_collection( form_data: QueryCollectionsForm, user=Depends(get_current_user), ): diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts index ca14371f..3f4f30bf 100644 --- a/src/lib/apis/rag/index.ts +++ b/src/lib/apis/rag/index.ts @@ -64,7 +64,7 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string return res; }; -export const queryCollection = async ( +export const queryDoc = async ( token: string, collection_name: string, query: string, @@ -72,6 +72,43 @@ export const queryCollection = async ( ) => { let error = null; + const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + authorization: `Bearer ${token}` + }, + body: JSON.stringify({ + collection_name: collection_name, + query: query, + k: k + }) + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .catch((err) => { + error = err.detail; + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const queryCollection = async ( + token: string, + collection_names: string, + query: string, + k: number +) => { + let error = null; + const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, { method: 'POST', headers: { @@ -80,7 +117,7 @@ export const queryCollection = async ( authorization: `Bearer ${token}` }, body: JSON.stringify({ - collection_name: collection_name, + collection_names: collection_names, query: query, k: k }) diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte index 761ba41c..0e0fc332 100644 --- a/src/lib/components/chat/Messages/UserMessage.svelte +++ b/src/lib/components/chat/Messages/UserMessage.svelte @@ -117,6 +117,35 @@
Document
+ {:else if file.type === 'collection'} + {/if} {/each} diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 956b6cb0..376c4e37 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -28,7 +28,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryCollection } from '$lib/apis/rag'; + import { queryCollection, queryDoc } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -224,7 +224,9 @@ const docs = messages .filter((message) => message?.files ?? null) - .map((message) => message.files.filter((item) => item.type === 'doc')) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) .flat(1); console.log(docs); @@ -234,12 +236,21 @@ let relevantContexts = await Promise.all( docs.map(async (doc) => { - return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); + if (doc.type === 'collection') { + return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } else { + return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } }) ); relevantContexts = relevantContexts.filter((context) => context); diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index fac8a01c..83e72c62 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -29,7 +29,7 @@ getTagsById, updateChatById } from '$lib/apis/chats'; - import { queryCollection } from '$lib/apis/rag'; + import { queryCollection, queryDoc } from '$lib/apis/rag'; import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import MessageInput from '$lib/components/chat/MessageInput.svelte'; @@ -238,7 +238,9 @@ const docs = messages .filter((message) => message?.files ?? null) - .map((message) => message.files.filter((item) => item.type === 'doc')) + .map((message) => + message.files.filter((item) => item.type === 'doc' || item.type === 'collection') + ) .flat(1); console.log(docs); @@ -248,12 +250,21 @@ let relevantContexts = await Promise.all( docs.map(async (doc) => { - return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch( - (error) => { - console.log(error); - return null; - } - ); + if (doc.type === 'collection') { + return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } else { + return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( + (error) => { + console.log(error); + return null; + } + ); + } }) ); relevantContexts = relevantContexts.filter((context) => context);