diff --git a/backend/apps/web/models/tags.py b/backend/apps/web/models/tags.py index eb33b31b..ef21ca08 100644 --- a/backend/apps/web/models/tags.py +++ b/backend/apps/web/models/tags.py @@ -15,7 +15,8 @@ from apps.web.internal.db import DB class Tag(Model): - name = CharField(unique=True) + id = CharField(unique=True) + name = CharField() user_id = CharField() data = TextField(null=True) @@ -24,7 +25,8 @@ class Tag(Model): class ChatIdTag(Model): - tag_name = ForeignKeyField(Tag, backref="chat_id_tags") + id = CharField(unique=True) + tag_name = CharField() chat_id = CharField() user_id = CharField() timestamp = DateField() @@ -34,12 +36,14 @@ class ChatIdTag(Model): class TagModel(BaseModel): + id: str name: str user_id: str data: Optional[str] = None class ChatIdTagModel(BaseModel): + id: str tag_name: str chat_id: str user_id: str @@ -70,14 +74,15 @@ class TagTable: db.create_tables([Tag, ChatIdTag]) def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: - tag = TagModel(**{"user_id": user_id, "name": name}) + id = str(uuid.uuid4()) + tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) try: result = Tag.create(**tag.model_dump()) if result: return tag else: return None - except: + except Exception as e: return None def get_tag_by_name_and_user_id( @@ -86,17 +91,27 @@ class TagTable: try: tag = Tag.get(Tag.name == name, Tag.user_id == user_id) return TagModel(**model_to_dict(tag)) - except: + except Exception as e: return None def add_tag_to_chat( self, user_id: str, form_data: ChatIdTagForm - ) -> Optional[ChatTagsResponse]: + ) -> Optional[ChatIdTagModel]: tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) if tag == None: tag = self.insert_new_tag(form_data.tag_name, user_id) - chatIdTag = ChatIdTagModel(**{"user_id": user_id, "tag_name": tag.name}) + print(tag) + id = str(uuid.uuid4()) + chatIdTag = ChatIdTagModel( + **{ + "id": id, + "user_id": user_id, + "chat_id": form_data.chat_id, + "tag_name": tag.name, + "timestamp": int(time.time()), + } + ) try: result = ChatIdTag.create(**chatIdTag.model_dump()) if result: @@ -109,19 +124,17 @@ class TagTable: def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: + tag_names = [ + ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name + for chat_id_tag in ChatIdTag.select() + .where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)) + .order_by(ChatIdTag.timestamp.desc()) + ] + + print(tag_names) return [ TagModel(**model_to_dict(tag)) - for tag in Tag.select().where( - Tag.name - in [ - ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name - for chat_id_tag in ChatIdTag.select() - .where( - (ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id) - ) - .order_by(ChatIdTag.timestamp.desc()) - ] - ) + for tag in Tag.select().where(Tag.name.in_(tag_names)) ] def get_chat_ids_by_tag_name_and_user_id( @@ -152,7 +165,8 @@ class TagTable: & (ChatIdTag.chat_id == chat_id) & (ChatIdTag.user_id == user_id) ) - query.execute() # Remove the rows, return number of rows removed. + res = query.execute() # Remove the rows, return number of rows removed. + print(res) tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) if tag_count == 0: @@ -163,7 +177,8 @@ class TagTable: query.execute() # Remove the rows, return number of rows removed. return True - except: + except Exception as e: + print("delete_tag", e) return False def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index 0c9aa573..38685826 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -19,6 +19,7 @@ from apps.web.models.chats import ( from apps.web.models.tags import ( TagModel, + ChatIdTagModel, ChatIdTagForm, ChatTagsResponse, Tags, @@ -132,7 +133,8 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)): async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) - if tags: + if tags != None: + print(tags) return tags else: raise HTTPException( @@ -145,17 +147,25 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): ############################ -@router.post("/{id}/tags", response_model=Optional[ChatTagsResponse]) +@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel]) async def add_chat_tag_by_id( id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) ): - tag = Tags.add_tag_to_chat(user.id, {"tag_name": form_data.tag_name, "chat_id": id}) + tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) - if tag: - return tag + if form_data.tag_name not in tags: + tag = Tags.add_tag_to_chat(user.id, form_data) + + if tag: + return tag + else: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=ERROR_MESSAGES.NOT_FOUND, + ) else: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT() ) diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index 521e7abc..5899d582 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -12,22 +12,11 @@ export let title: string = 'Ollama Web UI'; export let shareEnabled: boolean = false; - let showShareChatModal = false; + export let tags = []; + export let addTag: Function; + export let deleteTag: Function; - let tags = [ - // { - // name: 'general' - // }, - // { - // name: 'medicine' - // }, - // { - // name: 'cooking' - // }, - // { - // name: 'education' - // } - ]; + let showShareChatModal = false; let tagName = ''; let showTagInput = false; @@ -74,16 +63,17 @@ saveAs(blob, `chat-${chat.title}.txt`); }; - const addTag = () => { - if (!tags.find((e) => e.name === tagName)) { - tags = [ - ...tags, - { - name: JSON.parse(JSON.stringify(tagName)) - } - ]; - } + const addTagHandler = () => { + // if (!tags.find((e) => e.name === tagName)) { + // tags = [ + // ...tags, + // { + // name: JSON.parse(JSON.stringify(tagName)) + // } + // ]; + // } + addTag(tagName); tagName = ''; showTagInput = false; }; @@ -126,48 +116,19 @@