From 4e4076e2677979341ee1321338c0da3a55fdefd1 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 11 Nov 2023 20:18:03 -0800 Subject: [PATCH] feat: chat history support --- src/routes/+page.svelte | 522 ++++++++++++++++++++++++++++++---------- 1 file changed, 391 insertions(+), 131 deletions(-) diff --git a/src/routes/+page.svelte b/src/routes/+page.svelte index 4a36bb23..43bc9536 100644 --- a/src/routes/+page.svelte +++ b/src/routes/+page.svelte @@ -39,6 +39,22 @@ let title = ''; let prompt = ''; let messages = []; + let history = { + messages: {}, + currentId: null + }; + + $: if (history.currentId !== null) { + let _messages = []; + + let currentMessage = history.messages[history.currentId]; + while (currentMessage !== null) { + _messages.unshift({ ...currentMessage }); + currentMessage = + currentMessage.parentId !== null ? history.messages[currentMessage.parentId] : null; + } + messages = _messages; + } let showSettings = false; let stopResponseFlag = false; @@ -260,8 +276,13 @@ if (init || messages.length > 0) { chatId = uuidv4(); autoScroll = true; - messages = []; + title = ''; + messages = []; + history = { + messages: {}, + currentId: null + }; settings = JSON.parse(localStorage.getItem('settings') ?? JSON.stringify(settings)); @@ -311,18 +332,58 @@ const loadChat = async (id) => { const chat = await db.get('chats', id); + console.log(chat); if (chatId !== chat.id) { - if (chat.messages.length > 0) { - chat.messages.at(-1).done = true; + if ('history' in chat) { + history = chat.history; + } else { + let _history = { + messages: {}, + currentId: null + }; + + let parentMessageId = null; + let messageId = null; + + for (const message of chat.messages) { + messageId = uuidv4(); + + if (parentMessageId !== null) { + _history.messages[parentMessageId].childrenIds = [ + ..._history.messages[parentMessageId].childrenIds, + messageId + ]; + } + + _history.messages[messageId] = { + ...message, + id: messageId, + parentId: parentMessageId, + childrenIds: [] + }; + + parentMessageId = messageId; + } + _history.currentId = messageId; + + history = _history; } - messages = chat.messages; + + console.log(history); + title = chat.title; chatId = chat.id; selectedModel = chat.model ?? selectedModel; settings.system = chat.system ?? settings.system; settings.temperature = chat.temperature ?? settings.temperature; + autoScroll = true; await tick(); + + if (messages.length > 0) { + history.messages[messages.at(-1).id].done = true; + } + renderLatex(); hljs.highlightAll(); @@ -368,7 +429,8 @@ options: chat.options, title: chat.title, timestamp: chat.timestamp, - messages: chat.messages + messages: chat.messages, + history: chat.history }); } chats = await db.getAllFromIndex('chats', 'timestamp'); @@ -386,35 +448,44 @@ showSettings = true; }; - const editMessage = async (messageIdx) => { - messages = messages.map((message, idx) => { - if (messageIdx === idx) { - message.edit = true; - message.editedContent = message.content; - } - return message; - }); + const editMessageHandler = async (messageId) => { + // let editMessage = history.messages[messageId]; + history.messages[messageId].edit = true; + history.messages[messageId].editedContent = history.messages[messageId].content; }; - const confirmEditMessage = async (messageIdx) => { - let userPrompt = messages.at(messageIdx).editedContent; + const confirmEditMessage = async (messageId) => { + history.messages[messageId].edit = false; - messages.splice(messageIdx, messages.length - messageIdx); - messages = messages; + let userPrompt = history.messages[messageId].editedContent; + let userMessageId = uuidv4(); - await submitPrompt(userPrompt); + let userMessage = { + id: userMessageId, + parentId: history.messages[messageId].parentId, + childrenIds: [], + role: 'user', + content: userPrompt + }; + + let messageParentId = history.messages[messageId].parentId; + + if (messageParentId !== null) { + history.messages[messageParentId].childrenIds = [ + ...history.messages[messageParentId].childrenIds, + userMessageId + ]; + } + + history.messages[userMessageId] = userMessage; + history.currentId = userMessageId; + + await sendPrompt(userPrompt, userMessageId); }; - const cancelEditMessage = (messageIdx) => { - messages = messages.map((message, idx) => { - if (messageIdx === idx) { - message.edit = undefined; - message.editedContent = undefined; - } - return message; - }); - - console.log(messages); + const cancelEditMessage = (messageId) => { + history.messages[messageId].edit = false; + history.messages[messageId].editedContent = undefined; }; const rateMessage = async (messageIdx, rating) => { @@ -434,12 +505,89 @@ temperature: settings.temperature }, timestamp: Date.now(), - messages: messages + messages: messages, + history: history }); console.log(messages); }; + const showPreviousMessage = async (message) => { + if (message.parentId !== null) { + let messageId = + history.messages[message.parentId].childrenIds[ + Math.max(history.messages[message.parentId].childrenIds.indexOf(message.id) - 1, 0) + ]; + + if (message.id !== messageId) { + let messageChildrenIds = history.messages[messageId].childrenIds; + + while (messageChildrenIds.length !== 0) { + messageId = messageChildrenIds.at(-1); + messageChildrenIds = history.messages[messageId].childrenIds; + } + + history.currentId = messageId; + } + } else { + let childrenIds = Object.values(history.messages) + .filter((message) => message.parentId === null) + .map((message) => message.id); + let messageId = childrenIds[Math.max(childrenIds.indexOf(message.id) - 1, 0)]; + + if (message.id !== messageId) { + let messageChildrenIds = history.messages[messageId].childrenIds; + + while (messageChildrenIds.length !== 0) { + messageId = messageChildrenIds.at(-1); + messageChildrenIds = history.messages[messageId].childrenIds; + } + + history.currentId = messageId; + } + } + }; + + const showNextMessage = async (message) => { + if (message.parentId !== null) { + let messageId = + history.messages[message.parentId].childrenIds[ + Math.min( + history.messages[message.parentId].childrenIds.indexOf(message.id) + 1, + history.messages[message.parentId].childrenIds.length - 1 + ) + ]; + + if (message.id !== messageId) { + let messageChildrenIds = history.messages[messageId].childrenIds; + + while (messageChildrenIds.length !== 0) { + messageId = messageChildrenIds.at(-1); + messageChildrenIds = history.messages[messageId].childrenIds; + } + + history.currentId = messageId; + } + } else { + let childrenIds = Object.values(history.messages) + .filter((message) => message.parentId === null) + .map((message) => message.id); + let messageId = + childrenIds[Math.min(childrenIds.indexOf(message.id) + 1, childrenIds.length - 1)]; + + if (message.id !== messageId) { + let messageChildrenIds = history.messages[messageId].childrenIds; + + while (messageChildrenIds.length !== 0) { + messageId = messageChildrenIds.at(-1); + messageChildrenIds = history.messages[messageId].childrenIds; + } + + history.currentId = messageId; + } + } + }; + ////////////////////////// // Ollama functions ////////////////////////// @@ -507,21 +655,46 @@ } }; - const sendPrompt = async (userPrompt) => { + const sendPrompt = async (userPrompt, parentId) => { + // await Promise.all( + // selectedModels.map((model) => { + // if (selectedModel.includes('gpt-')) { + // await sendPromptOpenAI(userPrompt, parentId); + // } else { + // await sendPromptOllama(userPrompt, parentId); + // } + // }) + // ); + if (selectedModel.includes('gpt-')) { - await sendPromptOpenAI(userPrompt); + await sendPromptOpenAI(userPrompt, parentId); } else { - await sendPromptOllama(userPrompt); + await sendPromptOllama(userPrompt, parentId); } + + console.log(history); }; - const sendPromptOllama = async (userPrompt) => { + const sendPromptOllama = async (userPrompt, parentId) => { + let responseMessageId = uuidv4(); + let responseMessage = { + parentId: parentId, + id: responseMessageId, + childrenIds: [], role: 'assistant', content: '' }; - messages = [...messages, responseMessage]; + history.messages[responseMessageId] = responseMessage; + history.currentId = responseMessageId; + if (parentId !== null) { + history.messages[parentId].childrenIds = [ + ...history.messages[parentId].childrenIds, + responseMessageId + ]; + } + window.scrollTo({ top: document.body.scrollHeight }); const res = await fetch(`${API_BASE_URL}/generate`, { @@ -542,8 +715,9 @@ }, format: settings.requestFormat ?? undefined, context: - messages.length > 3 && messages.at(-3).context != undefined - ? messages.at(-3).context + history.messages[parentId] !== null && + history.messages[parentId].parentId in history.messages + ? history.messages[history.messages[parentId].parentId]?.context ?? undefined : undefined }) }); @@ -608,7 +782,8 @@ temperature: settings.temperature }, timestamp: Date.now(), - messages: messages + messages: messages, + history: history }); } @@ -715,7 +890,8 @@ temperature: settings.temperature }, timestamp: Date.now(), - messages: messages + messages: messages, + history: history }); } @@ -747,13 +923,22 @@ } else { document.getElementById('chat-textarea').style.height = ''; - messages = [ - ...messages, - { - role: 'user', - content: userPrompt - } - ]; + let userMessageId = uuidv4(); + + let userMessage = { + id: userMessageId, + parentId: messages.length !== 0 ? messages.at(-1).id : null, + childrenIds: [], + role: 'user', + content: userPrompt + }; + + if (messages.length !== 0) { + history.messages[messages.at(-1).id].childrenIds.push(userMessageId); + } + + history.messages[userMessageId] = userMessage; + history.currentId = userMessageId; prompt = ''; @@ -767,7 +952,8 @@ }, title: 'New Chat', timestamp: Date.now(), - messages: messages + messages: messages, + history: history }); chats = await db.getAllFromIndex('chats', 'timestamp'); } @@ -776,7 +962,7 @@ window.scrollTo({ top: document.body.scrollHeight, behavior: 'smooth' }); }, 50); - await sendPrompt(userPrompt); + await sendPrompt(userPrompt, userMessageId); chats = await db.getAllFromIndex('chats', 'timestamp'); } @@ -791,7 +977,7 @@ let userMessage = messages.at(-1); let userPrompt = userMessage.content; - await sendPrompt(userPrompt); + await sendPrompt(userPrompt, userMessage.id); chats = await db.getAllFromIndex('chats', 'timestamp'); } @@ -1078,7 +1264,7 @@