forked from open-webui/open-webui
feat: convo tag filtering
This commit is contained in:
parent
1eec176313
commit
220530c450
10 changed files with 214 additions and 27 deletions
|
@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ChatTable:
|
class ChatTable:
|
||||||
|
|
||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
self.db = db
|
self.db = db
|
||||||
db.create_tables([Chat])
|
db.create_tables([Chat])
|
||||||
|
|
||||||
def insert_new_chat(self, user_id: str,
|
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
|
||||||
form_data: ChatForm) -> Optional[ChatModel]:
|
|
||||||
id = str(uuid.uuid4())
|
id = str(uuid.uuid4())
|
||||||
chat = ChatModel(
|
chat = ChatModel(
|
||||||
**{
|
**{
|
||||||
"id": id,
|
"id": id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"title": form_data.chat["title"] if "title" in
|
"title": form_data.chat["title"]
|
||||||
form_data.chat else "New Chat",
|
if "title" in form_data.chat
|
||||||
|
else "New Chat",
|
||||||
"chat": json.dumps(form_data.chat),
|
"chat": json.dumps(form_data.chat),
|
||||||
"timestamp": int(time.time()),
|
"timestamp": int(time.time()),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
result = Chat.create(**chat.model_dump())
|
result = Chat.create(**chat.model_dump())
|
||||||
return chat if result else None
|
return chat if result else None
|
||||||
|
@ -109,25 +109,37 @@ class ChatTable:
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_chat_lists_by_user_id(self,
|
def get_chat_lists_by_user_id(
|
||||||
user_id: str,
|
self, user_id: str, skip: int = 0, limit: int = 50
|
||||||
skip: int = 0,
|
) -> List[ChatModel]:
|
||||||
limit: int = 50) -> List[ChatModel]:
|
|
||||||
return [
|
return [
|
||||||
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
|
ChatModel(**model_to_dict(chat))
|
||||||
Chat.user_id == user_id).order_by(Chat.timestamp.desc())
|
for chat in Chat.select()
|
||||||
|
.where(Chat.user_id == user_id)
|
||||||
|
.order_by(Chat.timestamp.desc())
|
||||||
# .limit(limit)
|
# .limit(limit)
|
||||||
# .offset(skip)
|
# .offset(skip)
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
def get_chat_lists_by_chat_ids(
|
||||||
|
self, chat_ids: List[str], skip: int = 0, limit: int = 50
|
||||||
|
) -> List[ChatModel]:
|
||||||
return [
|
return [
|
||||||
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where(
|
ChatModel(**model_to_dict(chat))
|
||||||
Chat.user_id == user_id).order_by(Chat.timestamp.desc())
|
for chat in Chat.select()
|
||||||
|
.where(Chat.id.in_(chat_ids))
|
||||||
|
.order_by(Chat.timestamp.desc())
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_chat_by_id_and_user_id(self, id: str,
|
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
|
||||||
user_id: str) -> Optional[ChatModel]:
|
return [
|
||||||
|
ChatModel(**model_to_dict(chat))
|
||||||
|
for chat in Chat.select()
|
||||||
|
.where(Chat.user_id == user_id)
|
||||||
|
.order_by(Chat.timestamp.desc())
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
|
||||||
try:
|
try:
|
||||||
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
|
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
|
||||||
return ChatModel(**model_to_dict(chat))
|
return ChatModel(**model_to_dict(chat))
|
||||||
|
@ -142,8 +154,7 @@ class ChatTable:
|
||||||
|
|
||||||
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
|
||||||
try:
|
try:
|
||||||
query = Chat.delete().where((Chat.id == id)
|
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
|
||||||
& (Chat.user_id == user_id))
|
|
||||||
query.execute() # Remove the rows, return number of rows removed.
|
query.execute() # Remove the rows, return number of rows removed.
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -120,6 +120,19 @@ class TagTable:
|
||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
|
||||||
|
tag_names = [
|
||||||
|
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
|
||||||
|
for chat_id_tag in ChatIdTag.select()
|
||||||
|
.where(ChatIdTag.user_id == user_id)
|
||||||
|
.order_by(ChatIdTag.timestamp.desc())
|
||||||
|
]
|
||||||
|
|
||||||
|
return [
|
||||||
|
TagModel(**model_to_dict(tag))
|
||||||
|
for tag in Tag.select().where(Tag.name.in_(tag_names))
|
||||||
|
]
|
||||||
|
|
||||||
def get_tags_by_chat_id_and_user_id(
|
def get_tags_by_chat_id_and_user_id(
|
||||||
self, chat_id: str, user_id: str
|
self, chat_id: str, user_id: str
|
||||||
) -> List[TagModel]:
|
) -> List[TagModel]:
|
||||||
|
|
|
@ -74,6 +74,42 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# GetAllTags
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tags/all", response_model=List[TagModel])
|
||||||
|
async def get_all_tags(user=Depends(get_current_user)):
|
||||||
|
try:
|
||||||
|
tags = Tags.get_tags_by_user_id(user.id)
|
||||||
|
return tags
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
############################
|
||||||
|
# GetChatsByTags
|
||||||
|
############################
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
|
||||||
|
async def get_user_chats_by_tag_name(
|
||||||
|
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
|
||||||
|
):
|
||||||
|
chat_ids = [
|
||||||
|
chat_id_tag.chat_id
|
||||||
|
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
|
||||||
|
]
|
||||||
|
|
||||||
|
print(chat_ids)
|
||||||
|
|
||||||
|
return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
|
||||||
|
|
||||||
|
|
||||||
############################
|
############################
|
||||||
# GetChatById
|
# GetChatById
|
||||||
############################
|
############################
|
||||||
|
|
|
@ -93,6 +93,68 @@ export const getAllChats = async (token: string) => {
|
||||||
return res;
|
return res;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export const getAllChatTags = async (token: string) => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: {
|
||||||
|
Accept: 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
...(token && { authorization: `Bearer ${token}` })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.then((json) => {
|
||||||
|
return json;
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
error = err;
|
||||||
|
console.log(err);
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const getChatListByTagName = async (token: string = '', tagName: string) => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/tag/${tagName}`, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: {
|
||||||
|
Accept: 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
...(token && { authorization: `Bearer ${token}` })
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.then((json) => {
|
||||||
|
return json;
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
error = err;
|
||||||
|
console.log(err);
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
export const getChatById = async (token: string, id: string) => {
|
export const getChatById = async (token: string, id: string) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@
|
||||||
import { getChatById } from '$lib/apis/chats';
|
import { getChatById } from '$lib/apis/chats';
|
||||||
import { chatId, modelfiles } from '$lib/stores';
|
import { chatId, modelfiles } from '$lib/stores';
|
||||||
import ShareChatModal from '../chat/ShareChatModal.svelte';
|
import ShareChatModal from '../chat/ShareChatModal.svelte';
|
||||||
import { stringify } from 'postcss';
|
|
||||||
|
|
||||||
export let initNewChat: Function;
|
export let initNewChat: Function;
|
||||||
export let title: string = 'Ollama Web UI';
|
export let title: string = 'Ollama Web UI';
|
||||||
|
|
|
@ -6,9 +6,14 @@
|
||||||
|
|
||||||
import { goto, invalidateAll } from '$app/navigation';
|
import { goto, invalidateAll } from '$app/navigation';
|
||||||
import { page } from '$app/stores';
|
import { page } from '$app/stores';
|
||||||
import { user, chats, settings, showSettings, chatId } from '$lib/stores';
|
import { user, chats, settings, showSettings, chatId, tags } from '$lib/stores';
|
||||||
import { onMount } from 'svelte';
|
import { onMount } from 'svelte';
|
||||||
import { deleteChatById, getChatList, updateChatById } from '$lib/apis/chats';
|
import {
|
||||||
|
deleteChatById,
|
||||||
|
getChatList,
|
||||||
|
getChatListByTagName,
|
||||||
|
updateChatById
|
||||||
|
} from '$lib/apis/chats';
|
||||||
|
|
||||||
let show = false;
|
let show = false;
|
||||||
let navElement;
|
let navElement;
|
||||||
|
@ -28,6 +33,12 @@
|
||||||
}
|
}
|
||||||
|
|
||||||
await chats.set(await getChatList(localStorage.token));
|
await chats.set(await getChatList(localStorage.token));
|
||||||
|
|
||||||
|
tags.subscribe(async (value) => {
|
||||||
|
if (value.length === 0) {
|
||||||
|
await chats.set(await getChatList(localStorage.token));
|
||||||
|
}
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
const loadChat = async (id) => {
|
const loadChat = async (id) => {
|
||||||
|
@ -281,6 +292,29 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{#if $tags.length > 0}
|
||||||
|
<div class="px-2.5 mt-0.5 mb-2 flex gap-1 flex-wrap">
|
||||||
|
<button
|
||||||
|
class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full"
|
||||||
|
on:click={async () => {
|
||||||
|
await chats.set(await getChatList(localStorage.token));
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
all
|
||||||
|
</button>
|
||||||
|
{#each $tags as tag}
|
||||||
|
<button
|
||||||
|
class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full"
|
||||||
|
on:click={async () => {
|
||||||
|
await chats.set(await getChatListByTagName(localStorage.token, tag.name));
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{tag.name}
|
||||||
|
</button>
|
||||||
|
{/each}
|
||||||
|
</div>
|
||||||
|
{/if}
|
||||||
|
|
||||||
<div class="pl-2.5 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto">
|
<div class="pl-2.5 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto">
|
||||||
{#each $chats.filter((chat) => {
|
{#each $chats.filter((chat) => {
|
||||||
if (search === '') {
|
if (search === '') {
|
||||||
|
|
|
@ -10,6 +10,7 @@ export const theme = writable('dark');
|
||||||
export const chatId = writable('');
|
export const chatId = writable('');
|
||||||
|
|
||||||
export const chats = writable([]);
|
export const chats = writable([]);
|
||||||
|
export const tags = writable([]);
|
||||||
export const models = writable([]);
|
export const models = writable([]);
|
||||||
|
|
||||||
export const modelfiles = writable([]);
|
export const modelfiles = writable([]);
|
||||||
|
|
|
@ -20,7 +20,8 @@
|
||||||
models,
|
models,
|
||||||
modelfiles,
|
modelfiles,
|
||||||
prompts,
|
prompts,
|
||||||
documents
|
documents,
|
||||||
|
tags
|
||||||
} from '$lib/stores';
|
} from '$lib/stores';
|
||||||
import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
|
import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
|
||||||
|
|
||||||
|
@ -29,6 +30,7 @@
|
||||||
import { checkVersion } from '$lib/utils';
|
import { checkVersion } from '$lib/utils';
|
||||||
import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte';
|
import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte';
|
||||||
import { getDocs } from '$lib/apis/documents';
|
import { getDocs } from '$lib/apis/documents';
|
||||||
|
import { getAllChatTags } from '$lib/apis/chats';
|
||||||
|
|
||||||
let ollamaVersion = '';
|
let ollamaVersion = '';
|
||||||
let loaded = false;
|
let loaded = false;
|
||||||
|
@ -106,6 +108,7 @@
|
||||||
await modelfiles.set(await getModelfiles(localStorage.token));
|
await modelfiles.set(await getModelfiles(localStorage.token));
|
||||||
await prompts.set(await getPrompts(localStorage.token));
|
await prompts.set(await getPrompts(localStorage.token));
|
||||||
await documents.set(await getDocs(localStorage.token));
|
await documents.set(await getDocs(localStorage.token));
|
||||||
|
await tags.set(await getAllChatTags(localStorage.token));
|
||||||
|
|
||||||
modelfiles.subscribe(async () => {
|
modelfiles.subscribe(async () => {
|
||||||
// should fetch models
|
// should fetch models
|
||||||
|
|
|
@ -6,7 +6,16 @@
|
||||||
import { goto } from '$app/navigation';
|
import { goto } from '$app/navigation';
|
||||||
import { page } from '$app/stores';
|
import { page } from '$app/stores';
|
||||||
|
|
||||||
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
|
import {
|
||||||
|
models,
|
||||||
|
modelfiles,
|
||||||
|
user,
|
||||||
|
settings,
|
||||||
|
chats,
|
||||||
|
chatId,
|
||||||
|
config,
|
||||||
|
tags as _tags
|
||||||
|
} from '$lib/stores';
|
||||||
import { copyToClipboard, splitStream } from '$lib/utils';
|
import { copyToClipboard, splitStream } from '$lib/utils';
|
||||||
|
|
||||||
import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
|
import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
|
||||||
|
@ -14,6 +23,7 @@
|
||||||
addTagById,
|
addTagById,
|
||||||
createNewChat,
|
createNewChat,
|
||||||
deleteTagById,
|
deleteTagById,
|
||||||
|
getAllChatTags,
|
||||||
getChatList,
|
getChatList,
|
||||||
getTagsById,
|
getTagsById,
|
||||||
updateChatById
|
updateChatById
|
||||||
|
@ -695,6 +705,8 @@
|
||||||
chat = await updateChatById(localStorage.token, $chatId, {
|
chat = await updateChatById(localStorage.token, $chatId, {
|
||||||
tags: tags
|
tags: tags
|
||||||
});
|
});
|
||||||
|
|
||||||
|
_tags.set(await getAllChatTags(localStorage.token));
|
||||||
};
|
};
|
||||||
|
|
||||||
const deleteTag = async (tagName) => {
|
const deleteTag = async (tagName) => {
|
||||||
|
@ -704,6 +716,8 @@
|
||||||
chat = await updateChatById(localStorage.token, $chatId, {
|
chat = await updateChatById(localStorage.token, $chatId, {
|
||||||
tags: tags
|
tags: tags
|
||||||
});
|
});
|
||||||
|
|
||||||
|
_tags.set(await getAllChatTags(localStorage.token));
|
||||||
};
|
};
|
||||||
|
|
||||||
const setChatTitle = async (_chatId, _title) => {
|
const setChatTitle = async (_chatId, _title) => {
|
||||||
|
|
|
@ -6,7 +6,16 @@
|
||||||
import { goto } from '$app/navigation';
|
import { goto } from '$app/navigation';
|
||||||
import { page } from '$app/stores';
|
import { page } from '$app/stores';
|
||||||
|
|
||||||
import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores';
|
import {
|
||||||
|
models,
|
||||||
|
modelfiles,
|
||||||
|
user,
|
||||||
|
settings,
|
||||||
|
chats,
|
||||||
|
chatId,
|
||||||
|
config,
|
||||||
|
tags as _tags
|
||||||
|
} from '$lib/stores';
|
||||||
import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
|
import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
|
||||||
|
|
||||||
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
|
import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
|
||||||
|
@ -14,6 +23,7 @@
|
||||||
addTagById,
|
addTagById,
|
||||||
createNewChat,
|
createNewChat,
|
||||||
deleteTagById,
|
deleteTagById,
|
||||||
|
getAllChatTags,
|
||||||
getChatById,
|
getChatById,
|
||||||
getChatList,
|
getChatList,
|
||||||
getTagsById,
|
getTagsById,
|
||||||
|
@ -709,8 +719,10 @@
|
||||||
tags = await getTags();
|
tags = await getTags();
|
||||||
|
|
||||||
chat = await updateChatById(localStorage.token, $chatId, {
|
chat = await updateChatById(localStorage.token, $chatId, {
|
||||||
tags: tags.map((tag) => tag.name)
|
tags: tags
|
||||||
});
|
});
|
||||||
|
|
||||||
|
_tags.set(await getAllChatTags(localStorage.token));
|
||||||
};
|
};
|
||||||
|
|
||||||
const deleteTag = async (tagName) => {
|
const deleteTag = async (tagName) => {
|
||||||
|
@ -718,8 +730,10 @@
|
||||||
tags = await getTags();
|
tags = await getTags();
|
||||||
|
|
||||||
chat = await updateChatById(localStorage.token, $chatId, {
|
chat = await updateChatById(localStorage.token, $chatId, {
|
||||||
tags: tags.map((tag) => tag.name)
|
tags: tags
|
||||||
});
|
});
|
||||||
|
|
||||||
|
_tags.set(await getAllChatTags(localStorage.token));
|
||||||
};
|
};
|
||||||
|
|
||||||
onMount(async () => {
|
onMount(async () => {
|
||||||
|
|
Loading…
Reference in a new issue