diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py index bc4659de..4895740d 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/web/models/chats.py @@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel): class ChatTable: - def __init__(self, db): self.db = db db.create_tables([Chat]) - def insert_new_chat(self, user_id: str, - form_data: ChatForm) -> Optional[ChatModel]: + def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: id = str(uuid.uuid4()) chat = ChatModel( **{ "id": id, "user_id": user_id, - "title": form_data.chat["title"] if "title" in - form_data.chat else "New Chat", + "title": form_data.chat["title"] + if "title" in form_data.chat + else "New Chat", "chat": json.dumps(form_data.chat), "timestamp": int(time.time()), - }) + } + ) result = Chat.create(**chat.model_dump()) return chat if result else None @@ -109,25 +109,37 @@ class ChatTable: except: return None - def get_chat_lists_by_user_id(self, - user_id: str, - skip: int = 0, - limit: int = 50) -> List[ChatModel]: + def get_chat_lists_by_user_id( + self, user_id: str, skip: int = 0, limit: int = 50 + ) -> List[ChatModel]: return [ - ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( - Chat.user_id == user_id).order_by(Chat.timestamp.desc()) + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.user_id == user_id) + .order_by(Chat.timestamp.desc()) # .limit(limit) # .offset(skip) ] - def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: + def get_chat_lists_by_chat_ids( + self, chat_ids: List[str], skip: int = 0, limit: int = 50 + ) -> List[ChatModel]: return [ - ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( - Chat.user_id == user_id).order_by(Chat.timestamp.desc()) + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.id.in_(chat_ids)) + .order_by(Chat.timestamp.desc()) ] - def get_chat_by_id_and_user_id(self, id: str, - user_id: str) -> Optional[ChatModel]: + def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: + return [ + ChatModel(**model_to_dict(chat)) + for chat in Chat.select() + .where(Chat.user_id == user_id) + .order_by(Chat.timestamp.desc()) + ] + + def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: try: chat = Chat.get(Chat.id == id, Chat.user_id == user_id) return ChatModel(**model_to_dict(chat)) @@ -142,8 +154,7 @@ class ChatTable: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: try: - query = Chat.delete().where((Chat.id == id) - & (Chat.user_id == user_id)) + query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) query.execute() # Remove the rows, return number of rows removed. return True diff --git a/backend/apps/web/models/tags.py b/backend/apps/web/models/tags.py index 5c6c094e..c14658cf 100644 --- a/backend/apps/web/models/tags.py +++ b/backend/apps/web/models/tags.py @@ -120,6 +120,19 @@ class TagTable: except: return None + def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: + tag_names = [ + ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name + for chat_id_tag in ChatIdTag.select() + .where(ChatIdTag.user_id == user_id) + .order_by(ChatIdTag.timestamp.desc()) + ] + + return [ + TagModel(**model_to_dict(tag)) + for tag in Tag.select().where(Tag.name.in_(tag_names)) + ] + def get_tags_by_chat_id_and_user_id( self, chat_id: str, user_id: str ) -> List[TagModel]: diff --git a/backend/apps/web/routers/chats.py b/backend/apps/web/routers/chats.py index c4d9ebab..29214229 100644 --- a/backend/apps/web/routers/chats.py +++ b/backend/apps/web/routers/chats.py @@ -74,6 +74,42 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): ) +############################ +# GetAllTags +############################ + + +@router.get("/tags/all", response_model=List[TagModel]) +async def get_all_tags(user=Depends(get_current_user)): + try: + tags = Tags.get_tags_by_user_id(user.id) + return tags + except Exception as e: + print(e) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() + ) + + +############################ +# GetChatsByTags +############################ + + +@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse]) +async def get_user_chats_by_tag_name( + tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50 +): + chat_ids = [ + chat_id_tag.chat_id + for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id) + ] + + print(chat_ids) + + return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit) + + ############################ # GetChatById ############################ diff --git a/src/lib/apis/chats/index.ts b/src/lib/apis/chats/index.ts index b7f01c6e..7c515f11 100644 --- a/src/lib/apis/chats/index.ts +++ b/src/lib/apis/chats/index.ts @@ -93,6 +93,68 @@ export const getAllChats = async (token: string) => { return res; }; +export const getAllChatTags = async (token: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + +export const getChatListByTagName = async (token: string = '', tagName: string) => { + let error = null; + + const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/tag/${tagName}`, { + method: 'GET', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + ...(token && { authorization: `Bearer ${token}` }) + } + }) + .then(async (res) => { + if (!res.ok) throw await res.json(); + return res.json(); + }) + .then((json) => { + return json; + }) + .catch((err) => { + error = err; + console.log(err); + return null; + }); + + if (error) { + throw error; + } + + return res; +}; + export const getChatById = async (token: string, id: string) => { let error = null; diff --git a/src/lib/components/layout/Navbar.svelte b/src/lib/components/layout/Navbar.svelte index 5899d582..3fc0ffa0 100644 --- a/src/lib/components/layout/Navbar.svelte +++ b/src/lib/components/layout/Navbar.svelte @@ -6,7 +6,6 @@ import { getChatById } from '$lib/apis/chats'; import { chatId, modelfiles } from '$lib/stores'; import ShareChatModal from '../chat/ShareChatModal.svelte'; - import { stringify } from 'postcss'; export let initNewChat: Function; export let title: string = 'Ollama Web UI'; diff --git a/src/lib/components/layout/Sidebar.svelte b/src/lib/components/layout/Sidebar.svelte index 927d87d6..afd4b0e5 100644 --- a/src/lib/components/layout/Sidebar.svelte +++ b/src/lib/components/layout/Sidebar.svelte @@ -6,9 +6,14 @@ import { goto, invalidateAll } from '$app/navigation'; import { page } from '$app/stores'; - import { user, chats, settings, showSettings, chatId } from '$lib/stores'; + import { user, chats, settings, showSettings, chatId, tags } from '$lib/stores'; import { onMount } from 'svelte'; - import { deleteChatById, getChatList, updateChatById } from '$lib/apis/chats'; + import { + deleteChatById, + getChatList, + getChatListByTagName, + updateChatById + } from '$lib/apis/chats'; let show = false; let navElement; @@ -28,6 +33,12 @@ } await chats.set(await getChatList(localStorage.token)); + + tags.subscribe(async (value) => { + if (value.length === 0) { + await chats.set(await getChatList(localStorage.token)); + } + }); }); const loadChat = async (id) => { @@ -281,6 +292,29 @@ + {#if $tags.length > 0} +
+ + {#each $tags as tag} + + {/each} +
+ {/if} +
{#each $chats.filter((chat) => { if (search === '') { diff --git a/src/lib/stores/index.ts b/src/lib/stores/index.ts index c7d8f5e6..7880235c 100644 --- a/src/lib/stores/index.ts +++ b/src/lib/stores/index.ts @@ -10,6 +10,7 @@ export const theme = writable('dark'); export const chatId = writable(''); export const chats = writable([]); +export const tags = writable([]); export const models = writable([]); export const modelfiles = writable([]); diff --git a/src/routes/(app)/+layout.svelte b/src/routes/(app)/+layout.svelte index 39ae0eea..c7839d93 100644 --- a/src/routes/(app)/+layout.svelte +++ b/src/routes/(app)/+layout.svelte @@ -20,7 +20,8 @@ models, modelfiles, prompts, - documents + documents, + tags } from '$lib/stores'; import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants'; @@ -29,6 +30,7 @@ import { checkVersion } from '$lib/utils'; import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte'; import { getDocs } from '$lib/apis/documents'; + import { getAllChatTags } from '$lib/apis/chats'; let ollamaVersion = ''; let loaded = false; @@ -106,6 +108,7 @@ await modelfiles.set(await getModelfiles(localStorage.token)); await prompts.set(await getPrompts(localStorage.token)); await documents.set(await getDocs(localStorage.token)); + await tags.set(await getAllChatTags(localStorage.token)); modelfiles.subscribe(async () => { // should fetch models diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 29e4f201..9507579c 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -6,7 +6,16 @@ import { goto } from '$app/navigation'; import { page } from '$app/stores'; - import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; + import { + models, + modelfiles, + user, + settings, + chats, + chatId, + config, + tags as _tags + } from '$lib/stores'; import { copyToClipboard, splitStream } from '$lib/utils'; import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama'; @@ -14,6 +23,7 @@ addTagById, createNewChat, deleteTagById, + getAllChatTags, getChatList, getTagsById, updateChatById @@ -695,6 +705,8 @@ chat = await updateChatById(localStorage.token, $chatId, { tags: tags }); + + _tags.set(await getAllChatTags(localStorage.token)); }; const deleteTag = async (tagName) => { @@ -704,6 +716,8 @@ chat = await updateChatById(localStorage.token, $chatId, { tags: tags }); + + _tags.set(await getAllChatTags(localStorage.token)); }; const setChatTitle = async (_chatId, _title) => { diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index 37f6f39c..206b7398 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -6,7 +6,16 @@ import { goto } from '$app/navigation'; import { page } from '$app/stores'; - import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; + import { + models, + modelfiles, + user, + settings, + chats, + chatId, + config, + tags as _tags + } from '$lib/stores'; import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils'; import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; @@ -14,6 +23,7 @@ addTagById, createNewChat, deleteTagById, + getAllChatTags, getChatById, getChatList, getTagsById, @@ -709,8 +719,10 @@ tags = await getTags(); chat = await updateChatById(localStorage.token, $chatId, { - tags: tags.map((tag) => tag.name) + tags: tags }); + + _tags.set(await getAllChatTags(localStorage.token)); }; const deleteTag = async (tagName) => { @@ -718,8 +730,10 @@ tags = await getTags(); chat = await updateChatById(localStorage.token, $chatId, { - tags: tags.map((tag) => tag.name) + tags: tags }); + + _tags.set(await getAllChatTags(localStorage.token)); }; onMount(async () => {