forked from open-webui/open-webui
		
	feat: convo tag filtering
This commit is contained in:
		
							parent
							
								
									1eec176313
								
							
						
					
					
						commit
						220530c450
					
				
					 10 changed files with 214 additions and 27 deletions
				
			
		|  | @ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel): | |||
| 
 | ||||
| 
 | ||||
| class ChatTable: | ||||
| 
 | ||||
|     def __init__(self, db): | ||||
|         self.db = db | ||||
|         db.create_tables([Chat]) | ||||
| 
 | ||||
|     def insert_new_chat(self, user_id: str, | ||||
|                         form_data: ChatForm) -> Optional[ChatModel]: | ||||
|     def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]: | ||||
|         id = str(uuid.uuid4()) | ||||
|         chat = ChatModel( | ||||
|             **{ | ||||
|                 "id": id, | ||||
|                 "user_id": user_id, | ||||
|                 "title": form_data.chat["title"] if "title" in | ||||
|                 form_data.chat else "New Chat", | ||||
|                 "title": form_data.chat["title"] | ||||
|                 if "title" in form_data.chat | ||||
|                 else "New Chat", | ||||
|                 "chat": json.dumps(form_data.chat), | ||||
|                 "timestamp": int(time.time()), | ||||
|             }) | ||||
|             } | ||||
|         ) | ||||
| 
 | ||||
|         result = Chat.create(**chat.model_dump()) | ||||
|         return chat if result else None | ||||
|  | @ -109,25 +109,37 @@ class ChatTable: | |||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_chat_lists_by_user_id(self, | ||||
|                                   user_id: str, | ||||
|                                   skip: int = 0, | ||||
|                                   limit: int = 50) -> List[ChatModel]: | ||||
|     def get_chat_lists_by_user_id( | ||||
|         self, user_id: str, skip: int = 0, limit: int = 50 | ||||
|     ) -> List[ChatModel]: | ||||
|         return [ | ||||
|             ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( | ||||
|                 Chat.user_id == user_id).order_by(Chat.timestamp.desc()) | ||||
|             ChatModel(**model_to_dict(chat)) | ||||
|             for chat in Chat.select() | ||||
|             .where(Chat.user_id == user_id) | ||||
|             .order_by(Chat.timestamp.desc()) | ||||
|             # .limit(limit) | ||||
|             # .offset(skip) | ||||
|         ] | ||||
| 
 | ||||
|     def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: | ||||
|     def get_chat_lists_by_chat_ids( | ||||
|         self, chat_ids: List[str], skip: int = 0, limit: int = 50 | ||||
|     ) -> List[ChatModel]: | ||||
|         return [ | ||||
|             ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( | ||||
|                 Chat.user_id == user_id).order_by(Chat.timestamp.desc()) | ||||
|             ChatModel(**model_to_dict(chat)) | ||||
|             for chat in Chat.select() | ||||
|             .where(Chat.id.in_(chat_ids)) | ||||
|             .order_by(Chat.timestamp.desc()) | ||||
|         ] | ||||
| 
 | ||||
|     def get_chat_by_id_and_user_id(self, id: str, | ||||
|                                    user_id: str) -> Optional[ChatModel]: | ||||
|     def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: | ||||
|         return [ | ||||
|             ChatModel(**model_to_dict(chat)) | ||||
|             for chat in Chat.select() | ||||
|             .where(Chat.user_id == user_id) | ||||
|             .order_by(Chat.timestamp.desc()) | ||||
|         ] | ||||
| 
 | ||||
|     def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: | ||||
|         try: | ||||
|             chat = Chat.get(Chat.id == id, Chat.user_id == user_id) | ||||
|             return ChatModel(**model_to_dict(chat)) | ||||
|  | @ -142,8 +154,7 @@ class ChatTable: | |||
| 
 | ||||
|     def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: | ||||
|         try: | ||||
|             query = Chat.delete().where((Chat.id == id) | ||||
|                                         & (Chat.user_id == user_id)) | ||||
|             query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) | ||||
|             query.execute()  # Remove the rows, return number of rows removed. | ||||
| 
 | ||||
|             return True | ||||
|  |  | |||
|  | @ -120,6 +120,19 @@ class TagTable: | |||
|         except: | ||||
|             return None | ||||
| 
 | ||||
|     def get_tags_by_user_id(self, user_id: str) -> List[TagModel]: | ||||
|         tag_names = [ | ||||
|             ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name | ||||
|             for chat_id_tag in ChatIdTag.select() | ||||
|             .where(ChatIdTag.user_id == user_id) | ||||
|             .order_by(ChatIdTag.timestamp.desc()) | ||||
|         ] | ||||
| 
 | ||||
|         return [ | ||||
|             TagModel(**model_to_dict(tag)) | ||||
|             for tag in Tag.select().where(Tag.name.in_(tag_names)) | ||||
|         ] | ||||
| 
 | ||||
|     def get_tags_by_chat_id_and_user_id( | ||||
|         self, chat_id: str, user_id: str | ||||
|     ) -> List[TagModel]: | ||||
|  |  | |||
|  | @ -74,6 +74,42 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): | |||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # GetAllTags | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/tags/all", response_model=List[TagModel]) | ||||
| async def get_all_tags(user=Depends(get_current_user)): | ||||
|     try: | ||||
|         tags = Tags.get_tags_by_user_id(user.id) | ||||
|         return tags | ||||
|     except Exception as e: | ||||
|         print(e) | ||||
|         raise HTTPException( | ||||
|             status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT() | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # GetChatsByTags | ||||
| ############################ | ||||
| 
 | ||||
| 
 | ||||
| @router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse]) | ||||
| async def get_user_chats_by_tag_name( | ||||
|     tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50 | ||||
| ): | ||||
|     chat_ids = [ | ||||
|         chat_id_tag.chat_id | ||||
|         for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id) | ||||
|     ] | ||||
| 
 | ||||
|     print(chat_ids) | ||||
| 
 | ||||
|     return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit) | ||||
| 
 | ||||
| 
 | ||||
| ############################ | ||||
| # GetChatById | ||||
| ############################ | ||||
|  |  | |||
|  | @ -93,6 +93,68 @@ export const getAllChats = async (token: string) => { | |||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const getAllChatTags = async (token: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/all`, { | ||||
| 		method: 'GET', | ||||
| 		headers: { | ||||
| 			Accept: 'application/json', | ||||
| 			'Content-Type': 'application/json', | ||||
| 			...(token && { authorization: `Bearer ${token}` }) | ||||
| 		} | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
| 		}) | ||||
| 		.then((json) => { | ||||
| 			return json; | ||||
| 		}) | ||||
| 		.catch((err) => { | ||||
| 			error = err; | ||||
| 			console.log(err); | ||||
| 			return null; | ||||
| 		}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const getChatListByTagName = async (token: string = '', tagName: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
| 	const res = await fetch(`${WEBUI_API_BASE_URL}/chats/tags/tag/${tagName}`, { | ||||
| 		method: 'GET', | ||||
| 		headers: { | ||||
| 			Accept: 'application/json', | ||||
| 			'Content-Type': 'application/json', | ||||
| 			...(token && { authorization: `Bearer ${token}` }) | ||||
| 		} | ||||
| 	}) | ||||
| 		.then(async (res) => { | ||||
| 			if (!res.ok) throw await res.json(); | ||||
| 			return res.json(); | ||||
| 		}) | ||||
| 		.then((json) => { | ||||
| 			return json; | ||||
| 		}) | ||||
| 		.catch((err) => { | ||||
| 			error = err; | ||||
| 			console.log(err); | ||||
| 			return null; | ||||
| 		}); | ||||
| 
 | ||||
| 	if (error) { | ||||
| 		throw error; | ||||
| 	} | ||||
| 
 | ||||
| 	return res; | ||||
| }; | ||||
| 
 | ||||
| export const getChatById = async (token: string, id: string) => { | ||||
| 	let error = null; | ||||
| 
 | ||||
|  |  | |||
|  | @ -6,7 +6,6 @@ | |||
| 	import { getChatById } from '$lib/apis/chats'; | ||||
| 	import { chatId, modelfiles } from '$lib/stores'; | ||||
| 	import ShareChatModal from '../chat/ShareChatModal.svelte'; | ||||
| 	import { stringify } from 'postcss'; | ||||
| 
 | ||||
| 	export let initNewChat: Function; | ||||
| 	export let title: string = 'Ollama Web UI'; | ||||
|  |  | |||
|  | @ -6,9 +6,14 @@ | |||
| 
 | ||||
| 	import { goto, invalidateAll } from '$app/navigation'; | ||||
| 	import { page } from '$app/stores'; | ||||
| 	import { user, chats, settings, showSettings, chatId } from '$lib/stores'; | ||||
| 	import { user, chats, settings, showSettings, chatId, tags } from '$lib/stores'; | ||||
| 	import { onMount } from 'svelte'; | ||||
| 	import { deleteChatById, getChatList, updateChatById } from '$lib/apis/chats'; | ||||
| 	import { | ||||
| 		deleteChatById, | ||||
| 		getChatList, | ||||
| 		getChatListByTagName, | ||||
| 		updateChatById | ||||
| 	} from '$lib/apis/chats'; | ||||
| 
 | ||||
| 	let show = false; | ||||
| 	let navElement; | ||||
|  | @ -28,6 +33,12 @@ | |||
| 		} | ||||
| 
 | ||||
| 		await chats.set(await getChatList(localStorage.token)); | ||||
| 
 | ||||
| 		tags.subscribe(async (value) => { | ||||
| 			if (value.length === 0) { | ||||
| 				await chats.set(await getChatList(localStorage.token)); | ||||
| 			} | ||||
| 		}); | ||||
| 	}); | ||||
| 
 | ||||
| 	const loadChat = async (id) => { | ||||
|  | @ -281,6 +292,29 @@ | |||
| 				</div> | ||||
| 			</div> | ||||
| 
 | ||||
| 			{#if $tags.length > 0} | ||||
| 				<div class="px-2.5 mt-0.5 mb-2 flex gap-1 flex-wrap"> | ||||
| 					<button | ||||
| 						class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full" | ||||
| 						on:click={async () => { | ||||
| 							await chats.set(await getChatList(localStorage.token)); | ||||
| 						}} | ||||
| 					> | ||||
| 						all | ||||
| 					</button> | ||||
| 					{#each $tags as tag} | ||||
| 						<button | ||||
| 							class="px-2.5 text-xs font-medium bg-gray-900 hover:bg-gray-800 transition rounded-full" | ||||
| 							on:click={async () => { | ||||
| 								await chats.set(await getChatListByTagName(localStorage.token, tag.name)); | ||||
| 							}} | ||||
| 						> | ||||
| 							{tag.name} | ||||
| 						</button> | ||||
| 					{/each} | ||||
| 				</div> | ||||
| 			{/if} | ||||
| 
 | ||||
| 			<div class="pl-2.5 my-2 flex-1 flex flex-col space-y-1 overflow-y-auto"> | ||||
| 				{#each $chats.filter((chat) => { | ||||
| 					if (search === '') { | ||||
|  |  | |||
|  | @ -10,6 +10,7 @@ export const theme = writable('dark'); | |||
| export const chatId = writable(''); | ||||
| 
 | ||||
| export const chats = writable([]); | ||||
| export const tags = writable([]); | ||||
| export const models = writable([]); | ||||
| 
 | ||||
| export const modelfiles = writable([]); | ||||
|  |  | |||
|  | @ -20,7 +20,8 @@ | |||
| 		models, | ||||
| 		modelfiles, | ||||
| 		prompts, | ||||
| 		documents | ||||
| 		documents, | ||||
| 		tags | ||||
| 	} from '$lib/stores'; | ||||
| 	import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants'; | ||||
| 
 | ||||
|  | @ -29,6 +30,7 @@ | |||
| 	import { checkVersion } from '$lib/utils'; | ||||
| 	import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte'; | ||||
| 	import { getDocs } from '$lib/apis/documents'; | ||||
| 	import { getAllChatTags } from '$lib/apis/chats'; | ||||
| 
 | ||||
| 	let ollamaVersion = ''; | ||||
| 	let loaded = false; | ||||
|  | @ -106,6 +108,7 @@ | |||
| 			await modelfiles.set(await getModelfiles(localStorage.token)); | ||||
| 			await prompts.set(await getPrompts(localStorage.token)); | ||||
| 			await documents.set(await getDocs(localStorage.token)); | ||||
| 			await tags.set(await getAllChatTags(localStorage.token)); | ||||
| 
 | ||||
| 			modelfiles.subscribe(async () => { | ||||
| 				// should fetch models | ||||
|  |  | |||
|  | @ -6,7 +6,16 @@ | |||
| 	import { goto } from '$app/navigation'; | ||||
| 	import { page } from '$app/stores'; | ||||
| 
 | ||||
| 	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; | ||||
| 	import { | ||||
| 		models, | ||||
| 		modelfiles, | ||||
| 		user, | ||||
| 		settings, | ||||
| 		chats, | ||||
| 		chatId, | ||||
| 		config, | ||||
| 		tags as _tags | ||||
| 	} from '$lib/stores'; | ||||
| 	import { copyToClipboard, splitStream } from '$lib/utils'; | ||||
| 
 | ||||
| 	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama'; | ||||
|  | @ -14,6 +23,7 @@ | |||
| 		addTagById, | ||||
| 		createNewChat, | ||||
| 		deleteTagById, | ||||
| 		getAllChatTags, | ||||
| 		getChatList, | ||||
| 		getTagsById, | ||||
| 		updateChatById | ||||
|  | @ -695,6 +705,8 @@ | |||
| 		chat = await updateChatById(localStorage.token, $chatId, { | ||||
| 			tags: tags | ||||
| 		}); | ||||
| 
 | ||||
| 		_tags.set(await getAllChatTags(localStorage.token)); | ||||
| 	}; | ||||
| 
 | ||||
| 	const deleteTag = async (tagName) => { | ||||
|  | @ -704,6 +716,8 @@ | |||
| 		chat = await updateChatById(localStorage.token, $chatId, { | ||||
| 			tags: tags | ||||
| 		}); | ||||
| 
 | ||||
| 		_tags.set(await getAllChatTags(localStorage.token)); | ||||
| 	}; | ||||
| 
 | ||||
| 	const setChatTitle = async (_chatId, _title) => { | ||||
|  |  | |||
|  | @ -6,7 +6,16 @@ | |||
| 	import { goto } from '$app/navigation'; | ||||
| 	import { page } from '$app/stores'; | ||||
| 
 | ||||
| 	import { models, modelfiles, user, settings, chats, chatId, config } from '$lib/stores'; | ||||
| 	import { | ||||
| 		models, | ||||
| 		modelfiles, | ||||
| 		user, | ||||
| 		settings, | ||||
| 		chats, | ||||
| 		chatId, | ||||
| 		config, | ||||
| 		tags as _tags | ||||
| 	} from '$lib/stores'; | ||||
| 	import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils'; | ||||
| 
 | ||||
| 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama'; | ||||
|  | @ -14,6 +23,7 @@ | |||
| 		addTagById, | ||||
| 		createNewChat, | ||||
| 		deleteTagById, | ||||
| 		getAllChatTags, | ||||
| 		getChatById, | ||||
| 		getChatList, | ||||
| 		getTagsById, | ||||
|  | @ -709,8 +719,10 @@ | |||
| 		tags = await getTags(); | ||||
| 
 | ||||
| 		chat = await updateChatById(localStorage.token, $chatId, { | ||||
| 			tags: tags.map((tag) => tag.name) | ||||
| 			tags: tags | ||||
| 		}); | ||||
| 
 | ||||
| 		_tags.set(await getAllChatTags(localStorage.token)); | ||||
| 	}; | ||||
| 
 | ||||
| 	const deleteTag = async (tagName) => { | ||||
|  | @ -718,8 +730,10 @@ | |||
| 		tags = await getTags(); | ||||
| 
 | ||||
| 		chat = await updateChatById(localStorage.token, $chatId, { | ||||
| 			tags: tags.map((tag) => tag.name) | ||||
| 			tags: tags | ||||
| 		}); | ||||
| 
 | ||||
| 		_tags.set(await getAllChatTags(localStorage.token)); | ||||
| 	}; | ||||
| 
 | ||||
| 	onMount(async () => { | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Timothy J. Baek
						Timothy J. Baek