fix: message type edge case

This commit is contained in:
Timothy J. Baek 2024-03-08 23:19:20 -08:00
parent 9f58ed5afa
commit d936353da0
3 changed files with 73 additions and 24 deletions

View file

@ -85,7 +85,24 @@ class RAGMiddleware(BaseHTTPMiddleware):
last_user_message_idx = i last_user_message_idx = i
break break
query = data["messages"][last_user_message_idx]["content"] user_message = data["messages"][last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
relevant_contexts = [] relevant_contexts = []
@ -112,16 +129,28 @@ class RAGMiddleware(BaseHTTPMiddleware):
if context: if context:
context_string += " ".join(context["documents"][0]) + "\n" context_string += " ".join(context["documents"][0]) + "\n"
content = rag_template( ra_content = rag_template(
template=rag_app.state.RAG_TEMPLATE, template=rag_app.state.RAG_TEMPLATE,
context=context_string, context=context_string,
query=query, query=query,
) )
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = { new_user_message = {
**data["messages"][last_user_message_idx], **user_message,
"content": content, "content": ra_content,
} }
data["messages"][last_user_message_idx] = new_user_message data["messages"][last_user_message_idx] = new_user_message
del data["docs"] del data["docs"]

View file

@ -295,15 +295,25 @@
...messages ...messages
] ]
.filter((message) => message) .filter((message) => message)
.map((message, idx, arr) => ({ .map((message, idx, arr) => {
// Prepare the base message object
const baseMessage = {
role: message.role, role: message.role,
content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content, content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
...(message.files && { };
images: message.files
.filter((file) => file.type === 'image') // Extract and format image URLs if any exist
.map((file) => file.url.slice(file.url.indexOf(',') + 1)) const imageUrls = message.files
}) ?.filter((file) => file.type === 'image')
})); .map((file) => file.url.slice(file.url.indexOf(',') + 1));
// Add images array only if it contains elements
if (imageUrls && imageUrls.length > 0) {
baseMessage.images = imageUrls;
}
return baseMessage;
});
let lastImageIndex = -1; let lastImageIndex = -1;

View file

@ -308,15 +308,25 @@
...messages ...messages
] ]
.filter((message) => message) .filter((message) => message)
.map((message, idx, arr) => ({ .map((message, idx, arr) => {
// Prepare the base message object
const baseMessage = {
role: message.role, role: message.role,
content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content, content: arr.length - 2 !== idx ? message.content : message?.raContent ?? message.content
...(message.files && { };
images: message.files
.filter((file) => file.type === 'image') // Extract and format image URLs if any exist
.map((file) => file.url.slice(file.url.indexOf(',') + 1)) const imageUrls = message.files
}) ?.filter((file) => file.type === 'image')
})); .map((file) => file.url.slice(file.url.indexOf(',') + 1));
// Add images array only if it contains elements
if (imageUrls && imageUrls.length > 0) {
baseMessage.images = imageUrls;
}
return baseMessage;
});
let lastImageIndex = -1; let lastImageIndex = -1;