From d936353da0f6131d0cf4157f02855902d78cb159 Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Fri, 8 Mar 2024 23:19:20 -0800 Subject: [PATCH] fix: message type edge case --- backend/main.py | 41 ++++++++++++++++++++++++---- src/routes/(app)/+page.svelte | 28 +++++++++++++------ src/routes/(app)/c/[id]/+page.svelte | 28 +++++++++++++------ 3 files changed, 73 insertions(+), 24 deletions(-) diff --git a/backend/main.py b/backend/main.py index bb424ae0..11ca81fc 100644 --- a/backend/main.py +++ b/backend/main.py @@ -85,7 +85,24 @@ class RAGMiddleware(BaseHTTPMiddleware): last_user_message_idx = i break - query = data["messages"][last_user_message_idx]["content"] + user_message = data["messages"][last_user_message_idx] + + if isinstance(user_message["content"], list): + # Handle list content input + content_type = "list" + query = "" + for content_item in user_message["content"]: + if content_item["type"] == "text": + query = content_item["text"] + break + elif isinstance(user_message["content"], str): + # Handle text content input + content_type = "text" + query = user_message["content"] + else: + # Fallback in case the input does not match expected types + content_type = None + query = "" relevant_contexts = [] @@ -112,16 +129,28 @@ class RAGMiddleware(BaseHTTPMiddleware): if context: context_string += " ".join(context["documents"][0]) + "\n" - content = rag_template( + ra_content = rag_template( template=rag_app.state.RAG_TEMPLATE, context=context_string, query=query, ) - new_user_message = { - **data["messages"][last_user_message_idx], - "content": content, - } + if content_type == "list": + new_content = [] + for content_item in user_message["content"]: + if content_item["type"] == "text": + # Update the text item's content with ra_content + new_content.append({"type": "text", "text": ra_content}) + else: + # Keep other types of content as they are + new_content.append(content_item) + new_user_message = {**user_message, "content": new_content} + else: + new_user_message = { + **user_message, + "content": ra_content, + } + data["messages"][last_user_message_idx] = new_user_message del data["docs"] diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte index 28bd8eb6..bb3668dc 100644 --- a/src/routes/(app)/+page.svelte +++ b/src/routes/(app)/+page.svelte @@ -295,15 +295,25 @@ ...messages ] .filter((message) => message) - .map((message, idx, arr) => ({ - role: message.role, - content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content, - ...(message.files && { - images: message.files - .filter((file) => file.type === 'image') - .map((file) => file.url.slice(file.url.indexOf(',') + 1)) - }) - })); + .map((message, idx, arr) => { + // Prepare the base message object + const baseMessage = { + role: message.role, + content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content + }; + + // Extract and format image URLs if any exist + const imageUrls = message.files + ?.filter((file) => file.type === 'image') + .map((file) => file.url.slice(file.url.indexOf(',') + 1)); + + // Add images array only if it contains elements + if (imageUrls && imageUrls.length > 0) { + baseMessage.images = imageUrls; + } + + return baseMessage; + }); let lastImageIndex = -1; diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte index 0ec3fae4..4bc6acfa 100644 --- a/src/routes/(app)/c/[id]/+page.svelte +++ b/src/routes/(app)/c/[id]/+page.svelte @@ -308,15 +308,25 @@ ...messages ] .filter((message) => message) - .map((message, idx, arr) => ({ - role: message.role, - content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content, - ...(message.files && { - images: message.files - .filter((file) => file.type === 'image') - .map((file) => file.url.slice(file.url.indexOf(',') + 1)) - }) - })); + .map((message, idx, arr) => { + // Prepare the base message object + const baseMessage = { + role: message.role, + content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content + }; + + // Extract and format image URLs if any exist + const imageUrls = message.files + ?.filter((file) => file.type === 'image') + .map((file) => file.url.slice(file.url.indexOf(',') + 1)); + + // Add images array only if it contains elements + if (imageUrls && imageUrls.length > 0) { + baseMessage.images = imageUrls; + } + + return baseMessage; + }); let lastImageIndex = -1;