feat: convo tagging full integration

This commit is contained in:
Timothy J. Baek 2024-01-18 02:10:07 -08:00
parent d5ed119687
commit 987685dbf9
5 changed files with 185 additions and 119 deletions

View file

@ -15,7 +15,8 @@ from apps.web.internal.db import DB
class Tag(Model): class Tag(Model):
name = CharField(unique=True) id = CharField(unique=True)
name = CharField()
user_id = CharField() user_id = CharField()
data = TextField(null=True) data = TextField(null=True)
@ -24,7 +25,8 @@ class Tag(Model):
class ChatIdTag(Model): class ChatIdTag(Model):
tag_name = ForeignKeyField(Tag, backref="chat_id_tags") id = CharField(unique=True)
tag_name = CharField()
chat_id = CharField() chat_id = CharField()
user_id = CharField() user_id = CharField()
timestamp = DateField() timestamp = DateField()
@ -34,12 +36,14 @@ class ChatIdTag(Model):
class TagModel(BaseModel): class TagModel(BaseModel):
id: str
name: str name: str
user_id: str user_id: str
data: Optional[str] = None data: Optional[str] = None
class ChatIdTagModel(BaseModel): class ChatIdTagModel(BaseModel):
id: str
tag_name: str tag_name: str
chat_id: str chat_id: str
user_id: str user_id: str
@ -70,14 +74,15 @@ class TagTable:
db.create_tables([Tag, ChatIdTag]) db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]: def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
tag = TagModel(**{"user_id": user_id, "name": name}) id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
result = Tag.create(**tag.model_dump()) result = Tag.create(**tag.model_dump())
if result: if result:
return tag return tag
else: else:
return None return None
except: except Exception as e:
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
@ -86,17 +91,27 @@ class TagTable:
try: try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id) tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
return TagModel(**model_to_dict(tag)) return TagModel(**model_to_dict(tag))
except: except Exception as e:
return None return None
def add_tag_to_chat( def add_tag_to_chat(
self, user_id: str, form_data: ChatIdTagForm self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatTagsResponse]: ) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id) tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag == None: if tag == None:
tag = self.insert_new_tag(form_data.tag_name, user_id) tag = self.insert_new_tag(form_data.tag_name, user_id)
chatIdTag = ChatIdTagModel(**{"user_id": user_id, "tag_name": tag.name}) print(tag)
id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel(
**{
"id": id,
"user_id": user_id,
"chat_id": form_data.chat_id,
"tag_name": tag.name,
"timestamp": int(time.time()),
}
)
try: try:
result = ChatIdTag.create(**chatIdTag.model_dump()) result = ChatIdTag.create(**chatIdTag.model_dump())
if result: if result:
@ -109,19 +124,17 @@ class TagTable:
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
.order_by(ChatIdTag.timestamp.desc())
]
print(tag_names)
return [ return [
TagModel(**model_to_dict(tag)) TagModel(**model_to_dict(tag))
for tag in Tag.select().where( for tag in Tag.select().where(Tag.name.in_(tag_names))
Tag.name
in [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where(
(ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id)
)
.order_by(ChatIdTag.timestamp.desc())
]
)
] ]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
@ -152,7 +165,8 @@ class TagTable:
& (ChatIdTag.chat_id == chat_id) & (ChatIdTag.chat_id == chat_id)
& (ChatIdTag.user_id == user_id) & (ChatIdTag.user_id == user_id)
) )
query.execute() # Remove the rows, return number of rows removed. res = query.execute() # Remove the rows, return number of rows removed.
print(res)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id) tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0: if tag_count == 0:
@ -163,7 +177,8 @@ class TagTable:
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
return True return True
except: except Exception as e:
print("delete_tag", e)
return False return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool: def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:

View file

@ -19,6 +19,7 @@ from apps.web.models.chats import (
from apps.web.models.tags import ( from apps.web.models.tags import (
TagModel, TagModel,
ChatIdTagModel,
ChatIdTagForm, ChatIdTagForm,
ChatTagsResponse, ChatTagsResponse,
Tags, Tags,
@ -132,7 +133,8 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)): async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags: if tags != None:
print(tags)
return tags return tags
else: else:
raise HTTPException( raise HTTPException(
@ -145,17 +147,25 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
############################ ############################
@router.post("/{id}/tags", response_model=Optional[ChatTagsResponse]) @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id( async def add_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user) id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
): ):
tag = Tags.add_tag_to_chat(user.id, {"tag_name": form_data.tag_name, "chat_id": id}) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tag: if form_data.tag_name not in tags:
return tag tag = Tags.add_tag_to_chat(user.id, form_data)
if tag:
return tag
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
) )

View file

@ -12,22 +12,11 @@
export let title: string = 'Ollama Web UI'; export let title: string = 'Ollama Web UI';
export let shareEnabled: boolean = false; export let shareEnabled: boolean = false;
let showShareChatModal = false; export let tags = [];
export let addTag: Function;
export let deleteTag: Function;
let tags = [ let showShareChatModal = false;
// {
// name: 'general'
// },
// {
// name: 'medicine'
// },
// {
// name: 'cooking'
// },
// {
// name: 'education'
// }
];
let tagName = ''; let tagName = '';
let showTagInput = false; let showTagInput = false;
@ -74,16 +63,17 @@
saveAs(blob, `chat-${chat.title}.txt`); saveAs(blob, `chat-${chat.title}.txt`);
}; };
const addTag = () => { const addTagHandler = () => {
if (!tags.find((e) => e.name === tagName)) { // if (!tags.find((e) => e.name === tagName)) {
tags = [ // tags = [
...tags, // ...tags,
{ // {
name: JSON.parse(JSON.stringify(tagName)) // name: JSON.parse(JSON.stringify(tagName))
} // }
]; // ];
} // }
addTag(tagName);
tagName = ''; tagName = '';
showTagInput = false; showTagInput = false;
}; };
@ -126,48 +116,19 @@
</div> </div>
<div class="pl-2 self-center flex items-center space-x-2"> <div class="pl-2 self-center flex items-center space-x-2">
<div class="flex flex-row space-x-0.5 line-clamp-1"> {#if shareEnabled}
{#each tags as tag} <div class="flex flex-row space-x-0.5 line-clamp-1">
<div {#each tags as tag}
class="px-2 py-0.5 space-x-1 flex h-fit items-center rounded-full transition border dark:border-gray-600 dark:text-white" <div
> class="px-2 py-0.5 space-x-1 flex h-fit items-center rounded-full transition border dark:border-gray-600 dark:text-white"
<div class=" text-[0.65rem] font-medium self-center line-clamp-1">
{tag.name}
</div>
<button
class=" m-auto self-center cursor-pointer"
on:click={() => {
console.log(tag.name);
tags = tags.filter((t) => t.name !== tag.name);
}}
> >
<svg <div class=" text-[0.65rem] font-medium self-center line-clamp-1">
xmlns="http://www.w3.org/2000/svg" {tag.name}
viewBox="0 0 16 16" </div>
fill="currentColor"
class="w-3 h-3"
>
<path
d="M5.28 4.22a.75.75 0 0 0-1.06 1.06L6.94 8l-2.72 2.72a.75.75 0 1 0 1.06 1.06L8 9.06l2.72 2.72a.75.75 0 1 0 1.06-1.06L9.06 8l2.72-2.72a.75.75 0 0 0-1.06-1.06L8 6.94 5.28 4.22Z"
/>
</svg>
</button>
</div>
{/each}
<div class="flex space-x-1 pl-1.5">
{#if showTagInput}
<div class="flex items-center">
<input
bind:value={tagName}
class=" cursor-pointer self-center text-xs h-fit bg-transparent outline-none line-clamp-1 w-[4rem]"
placeholder="Add a tag"
/>
<button <button
class=" m-auto self-center cursor-pointer"
on:click={() => { on:click={() => {
addTag(); deleteTag(tag.name);
}} }}
> >
<svg <svg
@ -177,40 +138,67 @@
class="w-3 h-3" class="w-3 h-3"
> >
<path <path
fill-rule="evenodd" d="M5.28 4.22a.75.75 0 0 0-1.06 1.06L6.94 8l-2.72 2.72a.75.75 0 1 0 1.06 1.06L8 9.06l2.72 2.72a.75.75 0 1 0 1.06-1.06L9.06 8l2.72-2.72a.75.75 0 0 0-1.06-1.06L8 6.94 5.28 4.22Z"
d="M12.416 3.376a.75.75 0 0 1 .208 1.04l-5 7.5a.75.75 0 0 1-1.154.114l-3-3a.75.75 0 0 1 1.06-1.06l2.353 2.353 4.493-6.74a.75.75 0 0 1 1.04-.207Z"
clip-rule="evenodd"
/> />
</svg> </svg>
</button> </button>
</div> </div>
{/each}
<!-- TODO: Tag Suggestions --> <div class="flex space-x-1 pl-1.5">
{/if} {#if showTagInput}
<div class="flex items-center">
<button <input
class=" cursor-pointer self-center p-0.5 space-x-1 flex h-fit items-center dark:hover:bg-gray-700 rounded-full transition border dark:border-gray-600 border-dashed" bind:value={tagName}
on:click={() => { class=" cursor-pointer self-center text-xs h-fit bg-transparent outline-none line-clamp-1 w-[4rem]"
showTagInput = !showTagInput; placeholder="Add a tag"
}}
>
<div class=" m-auto self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-3 h-3 {showTagInput ? 'rotate-45' : ''} transition-all transform"
>
<path
d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
/> />
</svg>
</div>
</button>
</div>
</div>
{#if shareEnabled} <button
on:click={() => {
addTagHandler();
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-3 h-3"
>
<path
fill-rule="evenodd"
d="M12.416 3.376a.75.75 0 0 1 .208 1.04l-5 7.5a.75.75 0 0 1-1.154.114l-3-3a.75.75 0 0 1 1.06-1.06l2.353 2.353 4.493-6.74a.75.75 0 0 1 1.04-.207Z"
clip-rule="evenodd"
/>
</svg>
</button>
</div>
<!-- TODO: Tag Suggestions -->
{/if}
<button
class=" cursor-pointer self-center p-0.5 space-x-1 flex h-fit items-center dark:hover:bg-gray-700 rounded-full transition border dark:border-gray-600 border-dashed"
on:click={() => {
showTagInput = !showTagInput;
}}
>
<div class=" m-auto self-center">
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-3 h-3 {showTagInput ? 'rotate-45' : ''} transition-all transform"
>
<path
d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
/>
</svg>
</div>
</button>
</div>
</div>
<button <button
class=" cursor-pointer p-1.5 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600" class=" cursor-pointer p-1.5 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600"
on:click={async () => { on:click={async () => {

View file

@ -10,7 +10,14 @@
import { copyToClipboard, splitStream } from '$lib/utils'; import { copyToClipboard, splitStream } from '$lib/utils';
import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama'; import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats'; import {
addTagById,
createNewChat,
deleteTagById,
getChatList,
getTagsById,
updateChatById
} from '$lib/apis/chats';
import { queryVectorDB } from '$lib/apis/rag'; import { queryVectorDB } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import { generateOpenAIChatCompletion } from '$lib/apis/openai';
@ -47,6 +54,7 @@
}, {}); }, {});
let chat = null; let chat = null;
let tags = [];
let title = ''; let title = '';
let prompt = ''; let prompt = '';
@ -673,6 +681,22 @@
} }
}; };
const getTags = async () => {
return await getTagsById(localStorage.token, $chatId).catch(async (error) => {
return [];
});
};
const addTag = async (tagName) => {
const res = await addTagById(localStorage.token, $chatId, tagName);
tags = await getTags();
};
const deleteTag = async (tagName) => {
const res = await deleteTagById(localStorage.token, $chatId, tagName);
tags = await getTags();
};
const setChatTitle = async (_chatId, _title) => { const setChatTitle = async (_chatId, _title) => {
if (_chatId === $chatId) { if (_chatId === $chatId) {
title = _title; title = _title;
@ -691,7 +715,7 @@
}} }}
/> />
<Navbar {title} shareEnabled={messages.length > 0} {initNewChat} /> <Navbar {title} shareEnabled={messages.length > 0} {initNewChat} {tags} {addTag} {deleteTag} />
<div class="min-h-screen w-full flex justify-center"> <div class="min-h-screen w-full flex justify-center">
<div class=" py-2.5 flex flex-col justify-between w-full"> <div class=" py-2.5 flex flex-col justify-between w-full">
<div class="max-w-2xl mx-auto w-full px-3 md:px-0 mt-10"> <div class="max-w-2xl mx-auto w-full px-3 md:px-0 mt-10">

View file

@ -10,7 +10,15 @@
import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils'; import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
import { createNewChat, getChatById, getChatList, updateChatById } from '$lib/apis/chats'; import {
addTagById,
createNewChat,
deleteTagById,
getChatById,
getChatList,
getTagsById,
updateChatById
} from '$lib/apis/chats';
import { queryVectorDB } from '$lib/apis/rag'; import { queryVectorDB } from '$lib/apis/rag';
import { generateOpenAIChatCompletion } from '$lib/apis/openai'; import { generateOpenAIChatCompletion } from '$lib/apis/openai';
@ -49,6 +57,7 @@
}, {}); }, {});
let chat = null; let chat = null;
let tags = [];
let title = ''; let title = '';
let prompt = ''; let prompt = '';
@ -97,6 +106,7 @@
}); });
if (chat) { if (chat) {
tags = await getTags();
const chatContent = chat.chat; const chatContent = chat.chat;
if (chatContent) { if (chatContent) {
@ -688,6 +698,22 @@
await chats.set(await getChatList(localStorage.token)); await chats.set(await getChatList(localStorage.token));
}; };
const getTags = async () => {
return await getTagsById(localStorage.token, $chatId).catch(async (error) => {
return [];
});
};
const addTag = async (tagName) => {
const res = await addTagById(localStorage.token, $chatId, tagName);
tags = await getTags();
};
const deleteTag = async (tagName) => {
const res = await deleteTagById(localStorage.token, $chatId, tagName);
tags = await getTags();
};
onMount(async () => { onMount(async () => {
if (!($settings.saveChatHistory ?? true)) { if (!($settings.saveChatHistory ?? true)) {
await goto('/'); await goto('/');
@ -713,6 +739,9 @@
goto('/'); goto('/');
}} }}
{tags}
{addTag}
{deleteTag}
/> />
<div class="min-h-screen w-full flex justify-center"> <div class="min-h-screen w-full flex justify-center">
<div class=" py-2.5 flex flex-col justify-between w-full"> <div class=" py-2.5 flex flex-col justify-between w-full">