From 6c58bb59bed875752fe3cb90edc499da7bb72957 Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Fri, 8 Mar 2024 22:43:06 -0800
Subject: [PATCH] feat: rag docs as payload field

---
 backend/main.py                      |  2 -
 src/routes/(app)/+page.svelte        |  6 ++-
 src/routes/(app)/c/[id]/+page.svelte | 70 +++++++++-------------------
 3 files changed, 25 insertions(+), 53 deletions(-)

diff --git a/backend/main.py b/backend/main.py
index cc5edc0f..d36c8420 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -123,8 +123,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
             data["messages"][last_user_message_idx] = new_user_message
             del data["docs"]
 
-            print("DATAAAAAAAAAAAAAAAAAA")
-            print(data)
             modified_body_bytes = json.dumps(data).encode("utf-8")
 
             # Create a new request with the modified body
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index e5510a06..28bd8eb6 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -336,7 +336,7 @@
 				},
 				format: $settings.requestFormat ?? undefined,
 				keep_alive: $settings.keepAlive ?? undefined,
-				docs: docs
+				docs: docs.length > 0 ? docs : undefined
 			});
 
 			if (res && res.ok) {
@@ -503,6 +503,8 @@
 			)
 			.flat(1);
 
+		console.log(docs);
+
 		const res = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
@@ -552,7 +554,7 @@
 					num_ctx: $settings?.options?.num_ctx ?? undefined,
 					frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
 					max_tokens: $settings?.options?.num_predict ?? undefined,
-					docs: docs
+					docs: docs.length > 0 ? docs : undefined
 				},
 				model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 			);
diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte
index dc9f8a58..0ec3fae4 100644
--- a/src/routes/(app)/c/[id]/+page.svelte
+++ b/src/routes/(app)/c/[id]/+page.svelte
@@ -245,53 +245,6 @@
 	const sendPrompt = async (prompt, parentId) => {
 		const _chatId = JSON.parse(JSON.stringify($chatId));
 
-		const docs = messages
-			.filter((message) => message?.files ?? null)
-			.map((message) =>
-				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
-			)
-			.flat(1);
-
-		console.log(docs);
-		if (docs.length > 0) {
-			processing = 'Reading';
-			const query = history.messages[parentId].content;
-
-			let relevantContexts = await Promise.all(
-				docs.map(async (doc) => {
-					if (doc.type === 'collection') {
-						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
-							(error) => {
-								console.log(error);
-								return null;
-							}
-						);
-					} else {
-						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
-							console.log(error);
-							return null;
-						});
-					}
-				})
-			);
-			relevantContexts = relevantContexts.filter((context) => context);
-
-			const contextString = relevantContexts.reduce((a, context, i, arr) => {
-				return `${a}${context.documents.join(' ')}\n`;
-			}, '');
-
-			console.log(contextString);
-
-			history.messages[parentId].raContent = await RAGTemplate(
-				localStorage.token,
-				contextString,
-				query
-			);
-			history.messages[parentId].contexts = relevantContexts;
-			await tick();
-			processing = '';
-		}
-
 		await Promise.all(
 			selectedModels.map(async (modelId) => {
 				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -381,6 +334,13 @@
 					}
 				});
 
+				const docs = messages
+					.filter((message) => message?.files ?? null)
+					.map((message) =>
+						message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+					)
+					.flat(1);
+
 				const [res, controller] = await generateChatCompletion(localStorage.token, {
 					model: model,
 					messages: messagesBody,
@@ -388,7 +348,8 @@
 					...($settings.options ?? {})
 				},
 				format: $settings.requestFormat ?? undefined,
-				keep_alive: $settings.keepAlive ?? undefined
+				keep_alive: $settings.keepAlive ?? undefined,
+				docs: docs.length > 0 ? docs : undefined
 			});
 
 			if (res && res.ok) {
@@ -548,6 +509,15 @@
 			const responseMessage = history.messages[responseMessageId];
 			scrollToBottom();
 
+			const docs = messages
+				.filter((message) => message?.files ?? null)
+				.map((message) =>
+					message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
+				)
+				.flat(1);
+
+			console.log(docs);
+
 			const res = await generateOpenAIChatCompletion(
 				localStorage.token,
 				{
@@ -596,7 +566,8 @@
 					top_p: $settings?.options?.top_p ?? undefined,
 					num_ctx: $settings?.options?.num_ctx ?? undefined,
 					frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-					max_tokens: $settings?.options?.num_predict ?? undefined
+					max_tokens: $settings?.options?.num_predict ?? undefined,
+					docs: docs.length > 0 ? docs : undefined
 				},
 				model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
 			);
@@ -710,6 +681,7 @@
 		await setChatTitle(_chatId, userPrompt);
 	}
 };
+
 const stopResponse = () => {
 	stopResponseFlag = true;
 	console.log('stopResponse');