forked from open-webui/open-webui
feat: collection rag integration
This commit is contained in:
parent
7d2f788a3b
commit
683650ec00
5 changed files with 112 additions and 24 deletions
|
@ -97,15 +97,15 @@ async def get_status():
|
||||||
return {"status": True}
|
return {"status": True}
|
||||||
|
|
||||||
|
|
||||||
class QueryCollectionForm(BaseModel):
|
class QueryDocForm(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
query: str
|
query: str
|
||||||
k: Optional[int] = 4
|
k: Optional[int] = 4
|
||||||
|
|
||||||
|
|
||||||
@app.post("/query/collection")
|
@app.post("/query/doc")
|
||||||
def query_collection(
|
def query_doc(
|
||||||
form_data: QueryCollectionForm,
|
form_data: QueryDocForm,
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
@ -173,8 +173,8 @@ def merge_and_sort_query_results(query_results, k):
|
||||||
return merged_query_results
|
return merged_query_results
|
||||||
|
|
||||||
|
|
||||||
@app.post("/query/collections")
|
@app.post("/query/collection")
|
||||||
def query_collections(
|
def query_collection(
|
||||||
form_data: QueryCollectionsForm,
|
form_data: QueryCollectionsForm,
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
|
|
|
@ -64,7 +64,7 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
export const queryCollection = async (
|
export const queryDoc = async (
|
||||||
token: string,
|
token: string,
|
||||||
collection_name: string,
|
collection_name: string,
|
||||||
query: string,
|
query: string,
|
||||||
|
@ -72,6 +72,43 @@ export const queryCollection = async (
|
||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${RAG_API_BASE_URL}/query/doc`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
Accept: 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
authorization: `Bearer ${token}`
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
collection_name: collection_name,
|
||||||
|
query: query,
|
||||||
|
k: k
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
error = err.detail;
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const queryCollection = async (
|
||||||
|
token: string,
|
||||||
|
collection_names: string,
|
||||||
|
query: string,
|
||||||
|
k: number
|
||||||
|
) => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
|
const res = await fetch(`${RAG_API_BASE_URL}/query/collection`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
|
@ -80,7 +117,7 @@ export const queryCollection = async (
|
||||||
authorization: `Bearer ${token}`
|
authorization: `Bearer ${token}`
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
collection_name: collection_name,
|
collection_names: collection_names,
|
||||||
query: query,
|
query: query,
|
||||||
k: k
|
k: k
|
||||||
})
|
})
|
||||||
|
|
|
@ -117,6 +117,35 @@
|
||||||
<div class=" text-gray-500 text-sm">Document</div>
|
<div class=" text-gray-500 text-sm">Document</div>
|
||||||
</div>
|
</div>
|
||||||
</button>
|
</button>
|
||||||
|
{:else if file.type === 'collection'}
|
||||||
|
<button
|
||||||
|
class="h-16 w-[15rem] flex items-center space-x-3 px-2.5 dark:bg-gray-600 rounded-xl border border-gray-200 dark:border-none text-left"
|
||||||
|
type="button"
|
||||||
|
>
|
||||||
|
<div class="p-2.5 bg-red-400 text-white rounded-lg">
|
||||||
|
<svg
|
||||||
|
xmlns="http://www.w3.org/2000/svg"
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
fill="currentColor"
|
||||||
|
class="w-6 h-6"
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
d="M7.5 3.375c0-1.036.84-1.875 1.875-1.875h.375a3.75 3.75 0 0 1 3.75 3.75v1.875C13.5 8.161 14.34 9 15.375 9h1.875A3.75 3.75 0 0 1 21 12.75v3.375C21 17.16 20.16 18 19.125 18h-9.75A1.875 1.875 0 0 1 7.5 16.125V3.375Z"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
d="M15 5.25a5.23 5.23 0 0 0-1.279-3.434 9.768 9.768 0 0 1 6.963 6.963A5.23 5.23 0 0 0 17.25 7.5h-1.875A.375.375 0 0 1 15 7.125V5.25ZM4.875 6H6v10.125A3.375 3.375 0 0 0 9.375 19.5H16.5v1.125c0 1.035-.84 1.875-1.875 1.875h-9.75A1.875 1.875 0 0 1 3 20.625V7.875C3 6.839 3.84 6 4.875 6Z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="flex flex-col justify-center -space-y-0.5">
|
||||||
|
<div class=" dark:text-gray-100 text-sm font-medium line-clamp-1">
|
||||||
|
#{file.name}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class=" text-gray-500 text-sm">Collection</div>
|
||||||
|
</div>
|
||||||
|
</button>
|
||||||
{/if}
|
{/if}
|
||||||
</div>
|
</div>
|
||||||
{/each}
|
{/each}
|
||||||
|
|
|
@ -28,7 +28,7 @@
|
||||||
getTagsById,
|
getTagsById,
|
||||||
updateChatById
|
updateChatById
|
||||||
} from '$lib/apis/chats';
|
} from '$lib/apis/chats';
|
||||||
import { queryCollection } from '$lib/apis/rag';
|
import { queryCollection, queryDoc } from '$lib/apis/rag';
|
||||||
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
||||||
|
|
||||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||||
|
@ -224,7 +224,9 @@
|
||||||
|
|
||||||
const docs = messages
|
const docs = messages
|
||||||
.filter((message) => message?.files ?? null)
|
.filter((message) => message?.files ?? null)
|
||||||
.map((message) => message.files.filter((item) => item.type === 'doc'))
|
.map((message) =>
|
||||||
|
message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
|
||||||
|
)
|
||||||
.flat(1);
|
.flat(1);
|
||||||
|
|
||||||
console.log(docs);
|
console.log(docs);
|
||||||
|
@ -234,12 +236,21 @@
|
||||||
|
|
||||||
let relevantContexts = await Promise.all(
|
let relevantContexts = await Promise.all(
|
||||||
docs.map(async (doc) => {
|
docs.map(async (doc) => {
|
||||||
return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch(
|
if (doc.type === 'collection') {
|
||||||
(error) => {
|
return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
|
||||||
console.log(error);
|
(error) => {
|
||||||
return null;
|
console.log(error);
|
||||||
}
|
return null;
|
||||||
);
|
}
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
|
||||||
|
(error) => {
|
||||||
|
console.log(error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
relevantContexts = relevantContexts.filter((context) => context);
|
relevantContexts = relevantContexts.filter((context) => context);
|
||||||
|
|
|
@ -29,7 +29,7 @@
|
||||||
getTagsById,
|
getTagsById,
|
||||||
updateChatById
|
updateChatById
|
||||||
} from '$lib/apis/chats';
|
} from '$lib/apis/chats';
|
||||||
import { queryCollection } from '$lib/apis/rag';
|
import { queryCollection, queryDoc } from '$lib/apis/rag';
|
||||||
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
import { generateOpenAIChatCompletion } from '$lib/apis/openai';
|
||||||
|
|
||||||
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
import MessageInput from '$lib/components/chat/MessageInput.svelte';
|
||||||
|
@ -238,7 +238,9 @@
|
||||||
|
|
||||||
const docs = messages
|
const docs = messages
|
||||||
.filter((message) => message?.files ?? null)
|
.filter((message) => message?.files ?? null)
|
||||||
.map((message) => message.files.filter((item) => item.type === 'doc'))
|
.map((message) =>
|
||||||
|
message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
|
||||||
|
)
|
||||||
.flat(1);
|
.flat(1);
|
||||||
|
|
||||||
console.log(docs);
|
console.log(docs);
|
||||||
|
@ -248,12 +250,21 @@
|
||||||
|
|
||||||
let relevantContexts = await Promise.all(
|
let relevantContexts = await Promise.all(
|
||||||
docs.map(async (doc) => {
|
docs.map(async (doc) => {
|
||||||
return await queryCollection(localStorage.token, doc.collection_name, query, 4).catch(
|
if (doc.type === 'collection') {
|
||||||
(error) => {
|
return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
|
||||||
console.log(error);
|
(error) => {
|
||||||
return null;
|
console.log(error);
|
||||||
}
|
return null;
|
||||||
);
|
}
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
|
||||||
|
(error) => {
|
||||||
|
console.log(error);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
relevantContexts = relevantContexts.filter((context) => context);
|
relevantContexts = relevantContexts.filter((context) => context);
|
||||||
|
|
Loading…
Reference in a new issue