This commit is contained in:
Timothy J. Baek 2024-03-10 18:40:50 -07:00
parent 88d324b52d
commit 8df6b137cb
2 changed files with 120 additions and 98 deletions

View file

@ -95,3 +95,89 @@ def rag_template(template: str, context: str, query: str):
template = re.sub(r"\[query\]", query, template)
return template
def rag_messages(docs, messages, template, k, embedding_function):
print(docs)
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_message_idx = i
break
user_message = messages[last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
relevant_contexts = []
for doc in docs:
context = None
try:
if doc["type"] == "collection":
context = query_collection(
collection_names=doc["collection_names"],
query=query,
k=k,
embedding_function=embedding_function,
)
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=k,
embedding_function=embedding_function,
)
except Exception as e:
print(e)
context = None
relevant_contexts.append(context)
context_string = ""
for context in relevant_contexts:
if context:
context_string += " ".join(context["documents"][0]) + "\n"
ra_content = rag_template(
template=template,
context=context_string,
query=query,
)
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
messages[last_user_message_idx] = new_user_message
return messages