feat: convo tag filtering

This commit is contained in:
Timothy J. Baek 2024-01-18 02:55:25 -08:00
parent 1eec176313
commit 220530c450
10 changed files with 214 additions and 27 deletions

View file

@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
db.create_tables([Chat]) db.create_tables([Chat])
def insert_new_chat(self, user_id: str, def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
form_data: ChatForm) -> Optional[ChatModel]:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": form_data.chat["title"] if "title" in "title": form_data.chat["title"]
form_data.chat else "New Chat", if "title" in form_data.chat
else "New Chat",
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"timestamp": int(time.time()), "timestamp": int(time.time()),
}) }
)
result = Chat.create(**chat.model_dump()) result = Chat.create(**chat.model_dump())
return chat if result else None return chat if result else None
@ -109,25 +109,37 @@ class ChatTable:
except: except:
return None return None
def get_chat_lists_by_user_id(self, def get_chat_lists_by_user_id(
user_id: str, self, user_id: str, skip: int = 0, limit: int = 50
skip: int = 0, ) -> List[ChatModel]:
limit: int = 50) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( ChatModel(**model_to_dict(chat))
Chat.user_id == user_id).order_by(Chat.timestamp.desc()) for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
# .limit(limit) # .limit(limit)
# .offset(skip) # .offset(skip)
] ]
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_chat_lists_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( ChatModel(**model_to_dict(chat))
Chat.user_id == user_id).order_by(Chat.timestamp.desc()) for chat in Chat.select()
.where(Chat.id.in_(chat_ids))
.order_by(Chat.timestamp.desc())
] ]
def get_chat_by_id_and_user_id(self, id: str, def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
user_id: str) -> Optional[ChatModel]: return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
]
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id) chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat)) return ChatModel(**model_to_dict(chat))
@ -142,8 +154,7 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
query = Chat.delete().where((Chat.id == id) query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
& (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
return True return True

View file

@ -120,6 +120,19 @@ class TagTable:
except: except:
return None return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where(ChatIdTag.user_id == user_id)
.order_by(ChatIdTag.timestamp.desc())
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select().where(Tag.name.in_(tag_names))
]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:

View file

@ -74,6 +74,42 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
) )
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chats_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
print(chat_ids)
return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
############################ ############################
# GetChatById # GetChatById
############################ ############################

View file

@ -93,6 +93,68 @@ export const getAllChats = async (token: string) => {
return res; return res;
}; };
export const getAllChatTags = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getChatListByTagName = async (token: string = '', tagName: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/tag/${tagName}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const getChatById = async (token: string, id: string) => { export const getChatById = async (token: string, id: string) => {
let error = null; let error = null;

View file

@ -6,7 +6,6 @@
import { getChatById } from '$lib/apis/chats'; import { getChatById } from '$lib/apis/chats';
import { chatId, modelfiles } from '$lib/stores'; import { chatId, modelfiles } from '$lib/stores';
import ShareChatModal from '../chat/ShareChatModal.svelte'; import ShareChatModal from '../chat/ShareChatModal.svelte';
import { stringify } from 'postcss';
export let initNewChat: Function; export let initNewChat: Function;
export let title: string = 'Ollama Web UI'; export let title: string = 'Ollama Web UI';

View file

@ -6,9 +6,14 @@
import { goto, invalidateAll } from '$app/navigation'; import { goto, invalidateAll } from '$app/navigation';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { user, chats, settings, showSettings, chatId } from '$lib/stores'; import { user, chats, settings, showSettings, chatId, tags } from '$lib/stores';
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import { deleteChatById, getChatList, updateChatById } from '$lib/apis/chats'; import {
deleteChatById,
getChatList,
getChatListByTagName,
updateChatById
} from '$lib/apis/chats';
let show = false; let show = false;
let navElement; let navElement;
@ -28,6 +33,12 @@
} }
await chats.set(await getChatList(localStorage.token)); await chats.set(await getChatList(localStorage.token));
tags.subscribe(async (value) => {
if (value.length === 0) {
await chats.set(await getChatList(localStorage.token));
}
});
}); });
const loadChat = async (id) => { const loadChat = async (id) => {
@ -281,6 +292,29 @@
</div> </div>
</div> </div>
{#if $tags.length > 0}
<div class="px-2.5 mt-0.5 mb-2 flex gap-1 flex-wrap">
<button
class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full"
on:click={async () => {
await chats.set(await getChatList(localStorage.token));
}}
>
all
</button>
{#each $tags as tag}
<button
class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full"
on:click={async () => {
await chats.set(await getChatListByTagName(localStorage.token, tag.name));
}}
>
{tag.name}
</button>
{/each}
</div>
{/if}
<div class="pl-2.5 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto"> <div class="pl-2.5 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto">
{#each $chats.filter((chat) => { {#each $chats.filter((chat) => {
if (search === '') { if (search === '') {

View file

@ -10,6 +10,7 @@ export const theme = writable('dark');
export const chatId = writable(''); export const chatId = writable('');
export const chats = writable([]); export const chats = writable([]);
export const tags = writable([]);
export const models = writable([]); export const models = writable([]);
export const modelfiles = writable([]); export const modelfiles = writable([]);

View file

@ -20,7 +20,8 @@
models, models,
modelfiles, modelfiles,
prompts, prompts,
documents documents,
tags
} from '$lib/stores'; } from '$lib/stores';
import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants'; import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
@ -29,6 +30,7 @@
import { checkVersion } from '$lib/utils'; import { checkVersion } from '$lib/utils';
import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte'; import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte';
import { getDocs } from '$lib/apis/documents'; import { getDocs } from '$lib/apis/documents';
import { getAllChatTags } from '$lib/apis/chats';
let ollamaVersion = ''; let ollamaVersion = '';
let loaded = false; let loaded = false;
@ -106,6 +108,7 @@
await modelfiles.set(await getModelfiles(localStorage.token)); await modelfiles.set(await getModelfiles(localStorage.token));
await prompts.set(await getPrompts(localStorage.token)); await prompts.set(await getPrompts(localStorage.token));
await documents.set(await getDocs(localStorage.token)); await documents.set(await getDocs(localStorage.token));
await tags.set(await getAllChatTags(localStorage.token));
modelfiles.subscribe(async () => { modelfiles.subscribe(async () => {
// should fetch models // should fetch models

View file

@ -6,7 +6,16 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; import {
models,
modelfiles,
user,
settings,
chats,
chatId,
config,
tags as _tags
} from '$lib/stores';
import { copyToClipboard, splitStream } from '$lib/utils'; import { copyToClipboard, splitStream } from '$lib/utils';
import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama'; import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
@ -14,6 +23,7 @@
addTagById, addTagById,
createNewChat, createNewChat,
deleteTagById, deleteTagById,
getAllChatTags,
getChatList, getChatList,
getTagsById, getTagsById,
updateChatById updateChatById
@ -695,6 +705,8 @@
chat = await updateChatById(localStorage.token, $chatId, { chat = await updateChatById(localStorage.token, $chatId, {
tags: tags tags: tags
}); });
_tags.set(await getAllChatTags(localStorage.token));
}; };
const deleteTag = async (tagName) => { const deleteTag = async (tagName) => {
@ -704,6 +716,8 @@
chat = await updateChatById(localStorage.token, $chatId, { chat = await updateChatById(localStorage.token, $chatId, {
tags: tags tags: tags
}); });
_tags.set(await getAllChatTags(localStorage.token));
}; };
const setChatTitle = async (_chatId, _title) => { const setChatTitle = async (_chatId, _title) => {

View file

@ -6,7 +6,16 @@
import { goto } from '$app/navigation'; import { goto } from '$app/navigation';
import { page } from '$app/stores'; import { page } from '$app/stores';
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; import {
models,
modelfiles,
user,
settings,
chats,
chatId,
config,
tags as _tags
} from '$lib/stores';
import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils'; import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
@ -14,6 +23,7 @@
addTagById, addTagById,
createNewChat, createNewChat,
deleteTagById, deleteTagById,
getAllChatTags,
getChatById, getChatById,
getChatList, getChatList,
getTagsById, getTagsById,
@ -709,8 +719,10 @@
tags = await getTags(); tags = await getTags();
chat = await updateChatById(localStorage.token, $chatId, { chat = await updateChatById(localStorage.token, $chatId, {
tags: tags.map((tag) => tag.name) tags: tags
}); });
_tags.set(await getAllChatTags(localStorage.token));
}; };
const deleteTag = async (tagName) => { const deleteTag = async (tagName) => {
@ -718,8 +730,10 @@
tags = await getTags(); tags = await getTags();
chat = await updateChatById(localStorage.token, $chatId, { chat = await updateChatById(localStorage.token, $chatId, {
tags: tags.map((tag) => tag.name) tags: tags
}); });
_tags.set(await getAllChatTags(localStorage.token));
}; };
onMount(async () => { onMount(async () => {