forked from open-webui/open-webui
feat: convo tagging backend support
This commit is contained in:
parent
287668f84e
commit
077f1fa34b
2 changed files with 245 additions and 0 deletions
180
backend/apps/web/models/tags.py
Normal file
180
backend/apps/web/models/tags.py
Normal file
|
@ -0,0 +1,180 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List, Union, Optional
|
||||||
|
from peewee import *
|
||||||
|
from playhouse.shortcuts import model_to_dict
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
import time
|
||||||
|
|
||||||
|
from apps.web.internal.db import DB
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Tag DB Schema
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
class Tag(Model):
|
||||||
|
name = CharField(unique=True)
|
||||||
|
user_id = CharField()
|
||||||
|
data = TextField(null=True)
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
database = DB
|
||||||
|
|
||||||
|
|
||||||
|
class ChatIdTag(Model):
|
||||||
|
tag_name = ForeignKeyField(Tag, backref="chat_id_tags")
|
||||||
|
chat_id = CharField()
|
||||||
|
user_id = CharField()
|
||||||
|
timestamp = DateField()
|
||||||
|
|
||||||
|
class Meta:
|
||||||
|
database = DB
|
||||||
|
|
||||||
|
|
||||||
|
class TagModel(BaseModel):
|
||||||
|
name: str
|
||||||
|
user_id: str
|
||||||
|
data: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatIdTagModel(BaseModel):
|
||||||
|
tag_name: str
|
||||||
|
chat_id: str
|
||||||
|
user_id: str
|
||||||
|
timestamp: int
|
||||||
|
|
||||||
|
|
||||||
|
####################
|
||||||
|
# Forms
|
||||||
|
####################
|
||||||
|
|
||||||
|
|
||||||
|
class ChatIdTagForm(BaseModel):
|
||||||
|
tag_name: str
|
||||||
|
chat_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class TagChatIdsResponse(BaseModel):
|
||||||
|
chat_ids: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ChatTagsResponse(BaseModel):
|
||||||
|
tags: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TagTable:
|
||||||
|
def __init__(self, db):
|
||||||
|
self.db = db
|
||||||
|
db.create_tables([Tag, ChatIdTag])
|
||||||
|
|
||||||
|
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
|
||||||
|
tag = TagModel(**{"user_id": user_id, "name": name})
|
||||||
|
try:
|
||||||
|
result = Tag.create(**tag.model_dump())
|
||||||
|
if result:
|
||||||
|
return tag
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_tag_by_name_and_user_id(
|
||||||
|
self, name: str, user_id: str
|
||||||
|
) -> Optional[TagModel]:
|
||||||
|
try:
|
||||||
|
tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
|
||||||
|
return TagModel(**model_to_dict(tag))
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def add_tag_to_chat(
|
||||||
|
self, user_id: str, form_data: ChatIdTagForm
|
||||||
|
) -> Optional[ChatTagsResponse]:
|
||||||
|
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
|
||||||
|
if tag == None:
|
||||||
|
tag = self.insert_new_tag(form_data.tag_name, user_id)
|
||||||
|
|
||||||
|
chatIdTag = ChatIdTagModel(**{"user_id": user_id, "tag_name": tag.name})
|
||||||
|
try:
|
||||||
|
result = ChatIdTag.create(**chatIdTag.model_dump())
|
||||||
|
if result:
|
||||||
|
return chatIdTag
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_tags_by_chat_id_and_user_id(
|
||||||
|
self, chat_id: str, user_id: str
|
||||||
|
) -> List[TagModel]:
|
||||||
|
return [
|
||||||
|
TagModel(**model_to_dict(tag))
|
||||||
|
for tag in Tag.select().where(
|
||||||
|
Tag.name
|
||||||
|
in [
|
||||||
|
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
|
||||||
|
for chat_id_tag in ChatIdTag.select()
|
||||||
|
.where(
|
||||||
|
(ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)
|
||||||
|
)
|
||||||
|
.order_by(ChatIdTag.timestamp.desc())
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_chat_ids_by_tag_name_and_user_id(
|
||||||
|
self, tag_name: str, user_id: str
|
||||||
|
) -> Optional[ChatIdTagModel]:
|
||||||
|
return [
|
||||||
|
ChatIdTagModel(**model_to_dict(chat_id_tag))
|
||||||
|
for chat_id_tag in ChatIdTag.select()
|
||||||
|
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name))
|
||||||
|
.order_by(ChatIdTag.timestamp.desc())
|
||||||
|
]
|
||||||
|
|
||||||
|
def count_chat_ids_by_tag_name_and_user_id(
|
||||||
|
self, tag_name: str, user_id: str
|
||||||
|
) -> int:
|
||||||
|
return (
|
||||||
|
ChatIdTag.select()
|
||||||
|
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id))
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
def delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||||
|
self, tag_name: str, chat_id: str, user_id: str
|
||||||
|
) -> bool:
|
||||||
|
try:
|
||||||
|
query = ChatIdTag.delete().where(
|
||||||
|
(ChatIdTag.tag_name == tag_name)
|
||||||
|
& (ChatIdTag.chat_id == chat_id)
|
||||||
|
& (ChatIdTag.user_id == user_id)
|
||||||
|
)
|
||||||
|
query.execute() # Remove the rows, return number of rows removed.
|
||||||
|
|
||||||
|
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
|
||||||
|
if tag_count == 0:
|
||||||
|
# Remove tag item from Tag col as well
|
||||||
|
query = Tag.delete().where(
|
||||||
|
(Tag.name == tag_name) & (Tag.user_id == user_id)
|
||||||
|
)
|
||||||
|
query.execute() # Remove the rows, return number of rows removed.
|
||||||
|
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
|
||||||
|
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
|
||||||
|
|
||||||
|
for tag in tags:
|
||||||
|
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||||
|
tag.tag_name, chat_id, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
Tags = TagTable(DB)
|
|
@ -16,6 +16,14 @@ from apps.web.models.chats import (
|
||||||
Chats,
|
Chats,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
from apps.web.models.tags import (
|
||||||
|
TagModel,
|
||||||
|
ChatIdTagForm,
|
||||||
|
ChatTagsResponse,
|
||||||
|
Tags,
|
||||||
|
)
|
||||||
|
|
||||||
from utils.utils import (
|
from utils.utils import (
|
||||||
bearer_scheme,
|
bearer_scheme,
|
||||||
)
|
)
|
||||||
|
@ -115,6 +123,63 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# GetChatTagsById
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{id}/tags", response_model=List[TagModel])
|
||||||
|
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
|
||||||
|
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
|
||||||
|
|
||||||
|
if tags:
|
||||||
|
return tags
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# AddChatTagById
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{id}/tags", response_model=Optional[ChatTagsResponse])
|
||||||
|
async def add_chat_tag_by_id(
|
||||||
|
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
|
||||||
|
):
|
||||||
|
tag = Tags.add_tag_to_chat(user.id, {"tag_name": form_data.tag_name, "chat_id": id})
|
||||||
|
|
||||||
|
if tag:
|
||||||
|
return tag
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# DeleteChatTagById
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{id}/tags", response_model=Optional[bool])
|
||||||
|
async def add_chat_tag_by_id(
|
||||||
|
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
|
||||||
|
):
|
||||||
|
tag = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
|
||||||
|
form_data.tag_name, id, user.id
|
||||||
|
)
|
||||||
|
|
||||||
|
if tag:
|
||||||
|
return tag
|
||||||
|
else:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# DeleteAllChats
|
# DeleteAllChats
|
||||||
############################
|
############################
|
||||||
|
|
Loading…
Reference in a new issue