forked from open-webui/open-webui
refac
This commit is contained in:
parent
485236624f
commit
50f7b20ac2
4 changed files with 91 additions and 58 deletions
|
@ -10,6 +10,7 @@ from fastapi import (
|
||||||
)
|
)
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import os, shutil
|
import os, shutil
|
||||||
|
from typing import List
|
||||||
|
|
||||||
# from chromadb.utils import embedding_functions
|
# from chromadb.utils import embedding_functions
|
||||||
|
|
||||||
|
@ -96,19 +97,22 @@ async def get_status():
|
||||||
return {"status": True}
|
return {"status": True}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/query/{collection_name}")
|
class QueryCollectionForm(BaseModel):
|
||||||
|
collection_name: str
|
||||||
|
query: str
|
||||||
|
k: Optional[int] = 4
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/query/collection")
|
||||||
def query_collection(
|
def query_collection(
|
||||||
collection_name: str,
|
form_data: QueryCollectionForm,
|
||||||
query: str,
|
|
||||||
k: Optional[int] = 4,
|
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
collection = CHROMA_CLIENT.get_collection(
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
name=collection_name,
|
name=form_data.collection_name,
|
||||||
)
|
)
|
||||||
result = collection.query(query_texts=[query], n_results=k)
|
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
@ -118,6 +122,34 @@ def query_collection(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QueryCollectionsForm(BaseModel):
|
||||||
|
collection_names: List[str]
|
||||||
|
query: str
|
||||||
|
k: Optional[int] = 4
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/query/collections")
|
||||||
|
def query_collections(
|
||||||
|
form_data: QueryCollectionsForm,
|
||||||
|
user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for collection_name in form_data.collection_names:
|
||||||
|
try:
|
||||||
|
collection = CHROMA_CLIENT.get_collection(
|
||||||
|
name=collection_name,
|
||||||
|
)
|
||||||
|
result = collection.query(
|
||||||
|
query_texts=[form_data.query], n_results=form_data.k
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
@app.post("/web")
|
@app.post("/web")
|
||||||
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
|
||||||
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
|
||||||
|
|
|
@ -66,28 +66,25 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
|
||||||
|
|
||||||
export const queryVectorDB = async (
|
export const queryVectorDB = async (
|
||||||
token: string,
|
token: string,
|
||||||
collection_name: string,
|
collection_names: string[],
|
||||||
query: string,
|
query: string,
|
||||||
k: number
|
k: number
|
||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
const searchParams = new URLSearchParams();
|
|
||||||
|
|
||||||
searchParams.set('query', query);
|
const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, {
|
||||||
if (k) {
|
method: 'POST',
|
||||||
searchParams.set('k', k.toString());
|
headers: {
|
||||||
}
|
Accept: 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
const res = await fetch(
|
authorization: `Bearer ${token}`
|
||||||
`${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`,
|
},
|
||||||
{
|
body: JSON.stringify({
|
||||||
method: 'GET',
|
collection_names: collection_names,
|
||||||
headers: {
|
query: query,
|
||||||
Accept: 'application/json',
|
k: k
|
||||||
authorization: `Bearer ${token}`
|
})
|
||||||
}
|
})
|
||||||
}
|
|
||||||
)
|
|
||||||
.then(async (res) => {
|
.then(async (res) => {
|
||||||
if (!res.ok) throw await res.json();
|
if (!res.ok) throw await res.json();
|
||||||
return res.json();
|
return res.json();
|
||||||
|
|
|
@ -232,26 +232,28 @@
|
||||||
processing = 'Reading';
|
processing = 'Reading';
|
||||||
const query = history.messages[parentId].content;
|
const query = history.messages[parentId].content;
|
||||||
|
|
||||||
let relevantContexts = await Promise.all(
|
let relevantContexts = await queryVectorDB(
|
||||||
docs.map(async (doc) => {
|
localStorage.token,
|
||||||
return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch(
|
docs.map((d) => d.collection_name),
|
||||||
(error) => {
|
query,
|
||||||
console.log(error);
|
4
|
||||||
return null;
|
).catch((error) => {
|
||||||
}
|
console.log(error);
|
||||||
);
|
return null;
|
||||||
})
|
});
|
||||||
);
|
|
||||||
relevantContexts = relevantContexts.filter((context) => context);
|
|
||||||
|
|
||||||
const contextString = relevantContexts.reduce((a, context, i, arr) => {
|
if (relevantContexts) {
|
||||||
return `${a}${context.documents.join(' ')}\n`;
|
relevantContexts = relevantContexts.filter((context) => context);
|
||||||
}, '');
|
|
||||||
|
|
||||||
console.log(contextString);
|
const contextString = relevantContexts.reduce((a, context, i, arr) => {
|
||||||
|
return `${a}${context.documents.join(' ')}\n`;
|
||||||
|
}, '');
|
||||||
|
|
||||||
history.messages[parentId].raContent = RAGTemplate(contextString, query);
|
console.log(contextString);
|
||||||
history.messages[parentId].contexts = relevantContexts;
|
|
||||||
|
history.messages[parentId].raContent = RAGTemplate(contextString, query);
|
||||||
|
history.messages[parentId].contexts = relevantContexts;
|
||||||
|
}
|
||||||
await tick();
|
await tick();
|
||||||
processing = '';
|
processing = '';
|
||||||
}
|
}
|
||||||
|
|
|
@ -246,26 +246,28 @@
|
||||||
processing = 'Reading';
|
processing = 'Reading';
|
||||||
const query = history.messages[parentId].content;
|
const query = history.messages[parentId].content;
|
||||||
|
|
||||||
let relevantContexts = await Promise.all(
|
let relevantContexts = await queryVectorDB(
|
||||||
docs.map(async (doc) => {
|
localStorage.token,
|
||||||
return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch(
|
docs.map((d) => d.collection_name),
|
||||||
(error) => {
|
query,
|
||||||
console.log(error);
|
4
|
||||||
return null;
|
).catch((error) => {
|
||||||
}
|
console.log(error);
|
||||||
);
|
return null;
|
||||||
})
|
});
|
||||||
);
|
|
||||||
relevantContexts = relevantContexts.filter((context) => context);
|
|
||||||
|
|
||||||
const contextString = relevantContexts.reduce((a, context, i, arr) => {
|
if (relevantContexts) {
|
||||||
return `${a}${context.documents.join(' ')}\n`;
|
relevantContexts = relevantContexts.filter((context) => context);
|
||||||
}, '');
|
|
||||||
|
|
||||||
console.log(contextString);
|
const contextString = relevantContexts.reduce((a, context, i, arr) => {
|
||||||
|
return `${a}${context.documents.join(' ')}\n`;
|
||||||
|
}, '');
|
||||||
|
|
||||||
history.messages[parentId].raContent = RAGTemplate(contextString, query);
|
console.log(contextString);
|
||||||
history.messages[parentId].contexts = relevantContexts;
|
|
||||||
|
history.messages[parentId].raContent = RAGTemplate(contextString, query);
|
||||||
|
history.messages[parentId].contexts = relevantContexts;
|
||||||
|
}
|
||||||
await tick();
|
await tick();
|
||||||
processing = '';
|
processing = '';
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue