refac: rag to backend

This commit is contained in:
Timothy J. Baek 2024-03-08 22:34:47 -08:00
parent 6ba62cf25d
commit c49491e516
4 changed files with 113 additions and 50 deletions

View file

@ -1,3 +1,4 @@
import re
from typing import List
from config import CHROMA_CLIENT
@ -87,3 +88,10 @@ def query_collection(
pass
return merge_and_sort_query_results(results, k)
def rag_template(template: str, context: str, query: str):
template = re.sub(r"\[context\]", context, template)
template = re.sub(r"\[query\]", query, template)
return template

View file

@ -12,6 +12,7 @@ from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from apps.ollama.main import app as ollama_app
@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
from apps.rag.utils import query_doc, query_collection, rag_template
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
from constants import ERROR_MESSAGES
@ -56,6 +59,89 @@ async def on_startup():
await litellm_app_startup()
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
print(request.url.path)
if request.method == "POST":
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
if "docs" in data:
docs = data["docs"]
print(docs)
last_user_message_idx = None
for i in range(len(data["messages"]) - 1, -1, -1):
if data["messages"][i]["role"] == "user":
last_user_message_idx = i
break
query = data["messages"][last_user_message_idx]["content"]
relevant_contexts = []
for doc in docs:
context = None
if doc["type"] == "collection":
context = query_collection(
collection_names=doc["collection_names"],
query=query,
k=rag_app.state.TOP_K,
embedding_function=rag_app.state.sentence_transformer_ef,
)
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=rag_app.state.TOP_K,
embedding_function=rag_app.state.sentence_transformer_ef,
)
relevant_contexts.append(context)
context_string = ""
for context in relevant_contexts:
if context:
context_string += " ".join(context["documents"][0]) + "\n"
content = rag_template(
template=rag_app.state.RAG_TEMPLATE,
context=context_string,
query=query,
)
new_user_message = {
**data["messages"][last_user_message_idx],
"content": content,
}
data["messages"][last_user_message_idx] = new_user_message
del data["docs"]
print("DATAAAAAAAAAAAAAAAAAA")
print(data)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Create a new request with the modified body
scope = request.scope
scope["body"] = modified_body_bytes
request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
response = await call_next(request)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(RAGMiddleware)
@app.middleware("http")
async def check_url(request: Request, call_next):
start_time = int(time.time())