This commit is contained in:
Timothy J. Baek 2024-02-01 13:35:41 -08:00
parent 485236624f
commit 50f7b20ac2
4 changed files with 91 additions and 58 deletions

View file

@ -10,6 +10,7 @@ from fastapi import (
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import os, shutil import os, shutil
from typing import List
# from chromadb.utils import embedding_functions # from chromadb.utils import embedding_functions
@ -96,19 +97,22 @@ async def get_status():
return {"status": True} return {"status": True}
@app.get("/query/{collection_name}") class QueryCollectionForm(BaseModel):
collection_name: str
query: str
k: Optional[int] = 4
@app.post("/query/collection")
def query_collection( def query_collection(
collection_name: str, form_data: QueryCollectionForm,
query: str,
k: Optional[int] = 4,
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=collection_name, name=form_data.collection_name,
) )
result = collection.query(query_texts=[query], n_results=k) result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result return result
except Exception as e: except Exception as e:
print(e) print(e)
@ -118,6 +122,34 @@ def query_collection(
) )
class QueryCollectionsForm(BaseModel):
collection_names: List[str]
query: str
k: Optional[int] = 4
@app.post("/query/collections")
def query_collections(
form_data: QueryCollectionsForm,
user=Depends(get_current_user),
):
results = []
for collection_name in form_data.collection_names:
try:
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
)
result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)
results.append(result)
except:
pass
return results
@app.post("/web") @app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"

View file

@ -66,28 +66,25 @@ export const uploadWebToVectorDB = async (token: string, collection_name: string
export const queryVectorDB = async ( export const queryVectorDB = async (
token: string, token: string,
collection_name: string, collection_names: string[],
query: string, query: string,
k: number k: number
) => { ) => {
let error = null; let error = null;
const searchParams = new URLSearchParams();
searchParams.set('query', query); const res = await fetch(`${RAG_API_BASE_URL}/query/collections`, {
if (k) { method: 'POST',
searchParams.set('k', k.toString());
}
const res = await fetch(
`${RAG_API_BASE_URL}/query/${collection_name}/?${searchParams.toString()}`,
{
method: 'GET',
headers: { headers: {
Accept: 'application/json', Accept: 'application/json',
'Content-Type': 'application/json',
authorization: `Bearer ${token}` authorization: `Bearer ${token}`
} },
} body: JSON.stringify({
) collection_names: collection_names,
query: query,
k: k
})
})
.then(async (res) => { .then(async (res) => {
if (!res.ok) throw await res.json(); if (!res.ok) throw await res.json();
return res.json(); return res.json();

View file

@ -232,16 +232,17 @@
processing = 'Reading'; processing = 'Reading';
const query = history.messages[parentId].content; const query = history.messages[parentId].content;
let relevantContexts = await Promise.all( let relevantContexts = await queryVectorDB(
docs.map(async (doc) => { localStorage.token,
return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( docs.map((d) => d.collection_name),
(error) => { query,
4
).catch((error) => {
console.log(error); console.log(error);
return null; return null;
} });
);
}) if (relevantContexts) {
);
relevantContexts = relevantContexts.filter((context) => context); relevantContexts = relevantContexts.filter((context) => context);
const contextString = relevantContexts.reduce((a, context, i, arr) => { const contextString = relevantContexts.reduce((a, context, i, arr) => {
@ -252,6 +253,7 @@
history.messages[parentId].raContent = RAGTemplate(contextString, query); history.messages[parentId].raContent = RAGTemplate(contextString, query);
history.messages[parentId].contexts = relevantContexts; history.messages[parentId].contexts = relevantContexts;
}
await tick(); await tick();
processing = ''; processing = '';
} }

View file

@ -246,16 +246,17 @@
processing = 'Reading'; processing = 'Reading';
const query = history.messages[parentId].content; const query = history.messages[parentId].content;
let relevantContexts = await Promise.all( let relevantContexts = await queryVectorDB(
docs.map(async (doc) => { localStorage.token,
return await queryVectorDB(localStorage.token, doc.collection_name, query, 4).catch( docs.map((d) => d.collection_name),
(error) => { query,
4
).catch((error) => {
console.log(error); console.log(error);
return null; return null;
} });
);
}) if (relevantContexts) {
);
relevantContexts = relevantContexts.filter((context) => context); relevantContexts = relevantContexts.filter((context) => context);
const contextString = relevantContexts.reduce((a, context, i, arr) => { const contextString = relevantContexts.reduce((a, context, i, arr) => {
@ -266,6 +267,7 @@
history.messages[parentId].raContent = RAGTemplate(contextString, query); history.messages[parentId].raContent = RAGTemplate(contextString, query);
history.messages[parentId].contexts = relevantContexts; history.messages[parentId].contexts = relevantContexts;
}
await tick(); await tick();
processing = ''; processing = '';
} }