This commit is contained in:
Timothy J. Baek 2024-02-01 13:35:41 -08:00
parent 485236624f
commit 50f7b20ac2
4 changed files with 91 additions and 58 deletions

View file

@ -10,6 +10,7 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil
from typing import List
# from chromadb.utils import embedding_functions
@ -96,19 +97,22 @@ async def get_status():
return {"status": True}
@app.get("/query/{collection_name}")
class QueryCollectionForm(BaseModel):
collection_name: str
query: str
k: Optional[int] = 4
@app.post("/query/collection")
def query_collection(
collection_name: str,
query: str,
k: Optional[int] = 4,
form_data: QueryCollectionForm,
user=Depends(get_current_user),
):
try:
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
name=form_data.collection_name,
)
result = collection.query(query_texts=[query], n_results=k)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
except Exception as e:
print(e)
@ -118,6 +122,34 @@ def query_collection(
)
class QueryCollectionsForm(BaseModel):
collection_names: List[str]
query: str
k: Optional[int] = 4
@app.post("/query/collections")
def query_collections(
form_data: QueryCollectionsForm,
user=Depends(get_current_user),
):
results = []
for collection_name in form_data.collection_names:
try:
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
)
result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)
results.append(result)
except:
pass
return results
@app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"