feat: doc tagging

This commit is contained in:
Timothy J. Baek 2024-02-03 14:44:49 -08:00
parent 8fd1b62e04
commit 00803c92f2
10 changed files with 344 additions and 108 deletions

View file

@ -128,6 +128,51 @@ class QueryCollectionsForm(BaseModel):
k: Optional[int] = 4
def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data
combined_ids = []
combined_distances = []
combined_metadatas = []
combined_documents = []
# Combine data from each dictionary
for data in query_results:
combined_ids.extend(data["ids"][0])
combined_distances.extend(data["distances"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_documents.extend(data["documents"][0])
# Create a list of tuples (distance, id, metadata, document)
combined = list(
zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
)
# Sort the list based on distances
combined.sort(key=lambda x: x[0])
# Unzip the sorted list
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_ids = list(sorted_ids)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
sorted_documents = list(sorted_documents)[:k]
# Create the output dictionary
merged_query_results = {
"ids": [sorted_ids],
"distances": [sorted_distances],
"metadatas": [sorted_metadatas],
"documents": [sorted_documents],
"embeddings": None,
"uris": None,
"data": None,
}
return merged_query_results
@app.post("/query/collections")
def query_collections(
form_data: QueryCollectionsForm,
@ -147,7 +192,7 @@ def query_collections(
except:
pass
return results
return merge_and_sort_query_results(results, form_data.k)
@app.post("/web")

View file

@ -44,6 +44,16 @@ class DocumentModel(BaseModel):
####################
class DocumentResponse(BaseModel):
collection_name: str
name: str
title: str
filename: str
content: Optional[dict] = None
user_id: str
timestamp: int # timestamp in epoch
class DocumentUpdateForm(BaseModel):
name: str
title: str
@ -111,6 +121,26 @@ class DocumentsTable:
print(e)
return None
def update_doc_content_by_name(
self, name: str, updated: dict
) -> Optional[DocumentModel]:
try:
doc = self.get_doc_by_name(name)
doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
query = Document.update(
content=json.dumps(doc_content),
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
print(e)
return None
def delete_doc_by_name(self, name: str) -> bool:
try:
query = Document.delete().where((Document.name == name))

View file

@ -11,6 +11,7 @@ from apps.web.models.documents import (
DocumentForm,
DocumentUpdateForm,
DocumentModel,
DocumentResponse,
)
from utils.utils import get_current_user
@ -23,9 +24,18 @@ router = APIRouter()
############################
@router.get("/", response_model=List[DocumentModel])
@router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user)):
return Documents.get_docs()
docs = [
DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
for doc in Documents.get_docs()
]
return docs
############################
@ -33,7 +43,7 @@ async def get_documents(user=Depends(get_current_user)):
############################
@router.post("/create", response_model=Optional[DocumentModel])
@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
@ -46,7 +56,12 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)
doc = Documents.insert_new_doc(user.id, form_data)
if doc:
return doc
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
@ -64,12 +79,45 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_current_user)
############################
@router.get("/name/{name}", response_model=Optional[DocumentModel])
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
if doc:
return doc
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# TagDocByName
############################
class TagDocumentForm(BaseModel):
name: str
tags: List[dict]
@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -82,7 +130,7 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
############################
@router.post("/name/{name}/update", response_model=Optional[DocumentModel])
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_current_user)
):
@ -94,7 +142,12 @@ async def update_doc_by_name(
doc = Documents.update_doc_by_name(name, form_data)
if doc:
return doc
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,