diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index de00a581..95535274 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -97,15 +97,15 @@ async def get_status():
return {"status": True}
-class QueryCollectionForm(BaseModel):
+class QueryDocForm(BaseModel):
collection_name: str
query: str
k: Optional[int] = 4
-@app.post("/query/collection")
-def query_collection(
- form_data: QueryCollectionForm,
+@app.post("/query/doc")
+def query_doc(
+ form_data: QueryDocForm,
user=Depends(get_current_user),
):
try:
@@ -173,8 +173,8 @@ def merge_and_sort_query_results(query_results, k):
return merged_query_results
-@app.post("/query/collections")
-def query_collections(
+@app.post("/query/collection")
+def query_collection(
form_data: QueryCollectionsForm,
user=Depends(get_current_user),
):
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index ca14371f..3f4f30bf 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -64,7 +64,7 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
return res;
};
-export const queryCollection = async (
+export const queryDoc = async (
token: string,
collection_name: string,
query: string,
@@ -72,6 +72,43 @@ export const queryCollection = async (
) => {
let error = null;
+ const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, {
+ method: 'POST',
+ headers: {
+ Accept: 'application/json',
+ 'Content-Type': 'application/json',
+ authorization: `Bearer ${token}`
+ },
+ body: JSON.stringify({
+ collection_name: collection_name,
+ query: query,
+ k: k
+ })
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ error = err.detail;
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
+export const queryCollection = async (
+ token: string,
+ collection_names: string,
+ query: string,
+ k: number
+) => {
+ let error = null;
+
const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
method: 'POST',
headers: {
@@ -80,7 +117,7 @@ export const queryCollection = async (
authorization: `Bearer ${token}`
},
body: JSON.stringify({
- collection_name: collection_name,
+ collection_names: collection_names,
query: query,
k: k
})
diff --git a/src/lib/components/chat/Messages/UserMessage.svelte b/src/lib/components/chat/Messages/UserMessage.svelte
index 761ba41c..0e0fc332 100644
--- a/src/lib/components/chat/Messages/UserMessage.svelte
+++ b/src/lib/components/chat/Messages/UserMessage.svelte
@@ -117,6 +117,35 @@
Document
+ {:else if file.type === 'collection'}
+
{/if}
{/each}
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index 956b6cb0..376c4e37 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -28,7 +28,7 @@
getTagsById,
updateChatById
} from '$lib/apis/chats';
- import { queryCollection } from '$lib/apis/rag';
+ import { queryCollection, queryDoc } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -224,7 +224,9 @@
const docs = messages
.filter((message) => message?.files ?? null)
- .map((message) => message.files.filter((item) => item.type === 'doc'))
+ .map((message) =>
+ message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+ )
.flat(1);
console.log(docs);
@@ -234,12 +236,21 @@
let relevantContexts = await Promise.all(
docs.map(async (doc) => {
- return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch(
- (error) => {
- console.log(error);
- return null;
- }
- );
+ if (doc.type === 'collection') {
+ return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
+ (error) => {
+ console.log(error);
+ return null;
+ }
+ );
+ } else {
+ return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
+ (error) => {
+ console.log(error);
+ return null;
+ }
+ );
+ }
})
);
relevantContexts = relevantContexts.filter((context) => context);
diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte
index fac8a01c..83e72c62 100644
--- a/src/routes/(app)/c/[id]/+page.svelte
+++ b/src/routes/(app)/c/[id]/+page.svelte
@@ -29,7 +29,7 @@
getTagsById,
updateChatById
} from '$lib/apis/chats';
- import { queryCollection } from '$lib/apis/rag';
+ import { queryCollection, queryDoc } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -238,7 +238,9 @@
const docs = messages
.filter((message) => message?.files ?? null)
- .map((message) => message.files.filter((item) => item.type === 'doc'))
+ .map((message) =>
+ message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+ )
.flat(1);
console.log(docs);
@@ -248,12 +250,21 @@
let relevantContexts = await Promise.all(
docs.map(async (doc) => {
- return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch(
- (error) => {
- console.log(error);
- return null;
- }
- );
+ if (doc.type === 'collection') {
+ return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
+ (error) => {
+ console.log(error);
+ return null;
+ }
+ );
+ } else {
+ return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
+ (error) => {
+ console.log(error);
+ return null;
+ }
+ );
+ }
})
);
relevantContexts = relevantContexts.filter((context) => context);