Merge pull request #1130 from open-webui/dev

fix: rag
This commit is contained in:
Timothy Jaeryang Baek 2024-03-10 20:41:58 -05:00 committed by GitHub
commit 11ca2703b0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 120 additions and 98 deletions

View file

@ -95,3 +95,89 @@ def rag_template(template: str, context: str, query: str):
template = re.sub(r"\[query\]", query, template) template = re.sub(r"\[query\]", query, template)
return template return template
def rag_messages(docs, messages, template, k, embedding_function):
print(docs)
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_message_idx = i
break
user_message = messages[last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
relevant_contexts = []
for doc in docs:
context = None
try:
if doc["type"] == "collection":
context = query_collection(
collection_names=doc["collection_names"],
query=query,
k=k,
embedding_function=embedding_function,
)
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=k,
embedding_function=embedding_function,
)
except Exception as e:
print(e)
context = None
relevant_contexts.append(context)
context_string = ""
for context in relevant_contexts:
if context:
context_string += " ".join(context["documents"][0]) + "\n"
ra_content = rag_template(
template=template,
context=context_string,
query=query,
)
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
messages[last_user_message_idx] = new_user_message
return messages

View file

@ -28,7 +28,7 @@ from typing import List
from utils.utils import get_admin_user from utils.utils import get_admin_user
from apps.rag.utils import query_doc, query_collection, rag_template from apps.rag.utils import rag_messages
from config import ( from config import (
WEBUI_NAME, WEBUI_NAME,
@ -60,19 +60,6 @@ app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
origins = ["*"] origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.on_event("startup")
async def on_startup():
await litellm_app_startup()
class RAGMiddleware(BaseHTTPMiddleware): class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
@ -91,98 +78,33 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Example: Add a new key-value pair or modify existing ones # Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification # data["modified"] = True # Example modification
if "docs" in data: if "docs" in data:
docs = data["docs"]
print(docs)
last_user_message_idx = None data = {**data}
for i in range(len(data["messages"]) - 1, -1, -1): data["messages"] = rag_messages(
if data["messages"][i]["role"] == "user": data["docs"],
last_user_message_idx = i data["messages"],
break rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K,
user_message = data["messages"][last_user_message_idx] rag_app.state.sentence_transformer_ef,
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
relevant_contexts = []
for doc in docs:
context = None
try:
if doc["type"] == "collection":
context = query_collection(
collection_names=doc["collection_names"],
query=query,
k=rag_app.state.TOP_K,
embedding_function=rag_app.state.sentence_transformer_ef,
) )
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=rag_app.state.TOP_K,
embedding_function=rag_app.state.sentence_transformer_ef,
)
except Exception as e:
print(e)
context = None
relevant_contexts.append(context)
context_string = ""
for context in relevant_contexts:
if context:
context_string += " ".join(context["documents"][0]) + "\n"
ra_content = rag_template(
template=rag_app.state.RAG_TEMPLATE,
context=context_string,
query=query,
)
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
data["messages"][last_user_message_idx] = new_user_message
del data["docs"] del data["docs"]
print(data["messages"]) print(data["messages"])
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
# Create a new request with the modified body # Replace the request body with the modified one
scope = request.scope request._body = modified_body_bytes
scope["body"] = modified_body_bytes
request = Request(scope, receive=lambda: self._receive(modified_body_bytes)) # Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
*[
(k, v)
for k, v in request.headers.raw
if k.lower() != b"content-length"
],
]
response = await call_next(request) response = await call_next(request)
return response return response
@ -194,6 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
app.add_middleware(RAGMiddleware) app.add_middleware(RAGMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
start_time = int(time.time()) start_time = int(time.time())
@ -204,6 +135,11 @@ async def check_url(request: Request, call_next):
return response return response
@app.on_event("startup")
async def on_startup():
await litellm_app_startup()
app.mount("/api/v1", webui_app) app.mount("/api/v1", webui_app)
app.mount("/litellm/api", litellm_app) app.mount("/litellm/api", litellm_app)