feat: convo tag filtering

This commit is contained in:
Timothy J. Baek 2024-01-18 02:55:25 -08:00
parent 1eec176313
commit 220530c450
10 changed files with 214 additions and 27 deletions

View file

@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def __init__(self, db):
self.db = db
db.create_tables([Chat])
def insert_new_chat(self, user_id: str,
form_data: ChatForm) -> Optional[ChatModel]:
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": form_data.chat["title"] if "title" in
form_data.chat else "New Chat",
"title": form_data.chat["title"]
if "title" in form_data.chat
else "New Chat",
"chat": json.dumps(form_data.chat),
"timestamp": int(time.time()),
})
}
)
result = Chat.create(**chat.model_dump())
return chat if result else None
@ -109,25 +109,37 @@ class ChatTable:
except:
return None
def get_chat_lists_by_user_id(self,
user_id: str,
skip: int = 0,
limit: int = 50) -> List[ChatModel]:
def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
Chat.user_id == user_id).order_by(Chat.timestamp.desc())
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
# .limit(limit)
# .offset(skip)
]
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
def get_chat_lists_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
Chat.user_id == user_id).order_by(Chat.timestamp.desc())
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.id.in_(chat_ids))
.order_by(Chat.timestamp.desc())
]
def get_chat_by_id_and_user_id(self, id: str,
user_id: str) -> Optional[ChatModel]:
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
]
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat))
@ -142,8 +154,7 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id)
& (Chat.user_id == user_id))
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed.
return True

View file

@ -120,6 +120,19 @@ class TagTable:
except:
return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where(ChatIdTag.user_id == user_id)
.order_by(ChatIdTag.timestamp.desc())
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select().where(Tag.name.in_(tag_names))
]
def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> List[TagModel]:

View file

@ -74,6 +74,42 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
)
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chats_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
print(chat_ids)
return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
############################
# GetChatById
############################