refac: rag to backend

This commit is contained in:
Timothy J. Baek 2024-03-08 22:34:47 -08:00
parent 6ba62cf25d
commit c49491e516
4 changed files with 113 additions and 50 deletions

View file

@@ -1,3 +1,4 @@
import re
from typing import List
from config import CHROMA_CLIENT
@@ -87,3 +88,10 @@ def query_collection(
pass
return merge_and_sort_query_results(results, k)
def rag_template(template: str, context: str, query: str) -> str:
    """Render a RAG prompt by filling the ``[context]`` and ``[query]``
    placeholders in *template*.

    Both placeholders are substituted in a single pass with a callback,
    which fixes two defects of sequential ``re.sub(pattern, text, ...)``:

    - retrieved document text or the user query is no longer treated as a
      regex *replacement* string, so backslashes and group references such
      as ``\\1`` cannot raise ``re.error`` or corrupt the output;
    - a literal ``[query]`` appearing inside the retrieved context is kept
      as-is instead of being substituted on a later pass.
    """
    return re.sub(
        r"\[context\]|\[query\]",
        lambda match: context if match.group() == "[context]" else query,
        template,
    )

View file

@@ -12,6 +12,7 @@ from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from apps.ollama.main import app as ollama_app
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
from apps.rag.utils import query_doc, query_collection, rag_template
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
from constants import ERROR_MESSAGES
@@ -56,6 +59,89 @@ async def on_startup():
await litellm_app_startup()
class RAGMiddleware(BaseHTTPMiddleware):
    """Inject retrieved document context into chat completion requests.

    For POST requests whose JSON body contains a ``docs`` list, the
    middleware queries each referenced document or collection, renders
    the configured RAG template with the retrieved context and the last
    user message, replaces that message's content with the rendered
    prompt, strips the ``docs`` key, and forwards the rewritten body
    downstream. All other requests pass through untouched.
    """

    async def dispatch(self, request: Request, call_next):
        if request.method == "POST":
            # Read and parse the original request body.
            body = await request.body()
            body_str = body.decode("utf-8")
            data = json.loads(body_str) if body_str else {}

            if "docs" in data:
                # Remove "docs" so the downstream endpoint never sees it
                # (same effect as the original `del data["docs"]`).
                docs = data.pop("docs")

                # The retrieval query is the most recent user message.
                last_user_message_idx = None
                for i in range(len(data.get("messages", [])) - 1, -1, -1):
                    if data["messages"][i]["role"] == "user":
                        last_user_message_idx = i
                        break

                # Guard against payloads with no user message: indexing
                # with None would raise TypeError in the original code.
                if last_user_message_idx is not None:
                    query = data["messages"][last_user_message_idx]["content"]

                    relevant_contexts = []
                    for doc in docs:
                        if doc["type"] == "collection":
                            context = query_collection(
                                collection_names=doc["collection_names"],
                                query=query,
                                k=rag_app.state.TOP_K,
                                embedding_function=rag_app.state.sentence_transformer_ef,
                            )
                        else:
                            context = query_doc(
                                collection_name=doc["collection_name"],
                                query=query,
                                k=rag_app.state.TOP_K,
                                embedding_function=rag_app.state.sentence_transformer_ef,
                            )
                        relevant_contexts.append(context)

                    # Concatenate the first document list of every non-empty
                    # query result into a single context string.
                    context_string = ""
                    for context in relevant_contexts:
                        if context:
                            context_string += " ".join(context["documents"][0]) + "\n"

                    content = rag_template(
                        template=rag_app.state.RAG_TEMPLATE,
                        context=context_string,
                        query=query,
                    )

                    # Replace only the content of the last user message,
                    # preserving any other keys it carried.
                    data["messages"][last_user_message_idx] = {
                        **data["messages"][last_user_message_idx],
                        "content": content,
                    }

                # Re-inject the rewritten body so downstream consumers read
                # the modified JSON instead of the already-consumed original.
                modified_body_bytes = json.dumps(data).encode("utf-8")
                request.scope["body"] = modified_body_bytes
                request = Request(
                    request.scope,
                    receive=lambda: self._receive(modified_body_bytes),
                )

        return await call_next(request)

    async def _receive(self, body: bytes):
        # Minimal ASGI receive() replacement that yields the rewritten
        # body as a single, final http.request message.
        return {"type": "http.request", "body": body, "more_body": False}
# Register the RAG middleware on the root app so chat requests carrying
# "docs" are enriched with retrieved context before reaching the backends.
app.add_middleware(RAGMiddleware)
@app.middleware("http")
async def check_url(request: Request, call_next):
start_time = int(time.time())

View file

@@ -252,7 +252,7 @@ export const queryCollection = async (
token: string,
collection_names: string,
query: string,
k: number
k: number | null = null
) => {
let error = null;

View file

@@ -232,53 +232,6 @@
const sendPrompt = async (prompt, parentId) => {
const _chatId = JSON.parse(JSON.stringify($chatId));
const docs = messages
.filter((message) => message?.files ?? null)
.map((message) =>
message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
)
.flat(1);
console.log(docs);
if (docs.length > 0) {
processing = 'Reading';
const query = history.messages[parentId].content;
let relevantContexts = await Promise.all(
docs.map(async (doc) => {
if (doc.type === 'collection') {
return await queryCollection(localStorage.token, doc.collection_names, query).catch(
(error) => {
console.log(error);
return null;
}
);
} else {
return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
console.log(error);
return null;
});
}
})
);
relevantContexts = relevantContexts.filter((context) => context);
const contextString = relevantContexts.reduce((a, context, i, arr) => {
return `${a}${context.documents.join(' ')}\n`;
}, '');
console.log(contextString);
history.messages[parentId].raContent = await RAGTemplate(
localStorage.token,
contextString,
query
);
history.messages[parentId].contexts = relevantContexts;
await tick();
processing = '';
}
await Promise.all(
selectedModels.map(async (modelId) => {
const model = $models.filter((m) => m.id === modelId).at(0);
@@ -368,6 +321,13 @@
}
});
const docs = messages
.filter((message) => message?.files ?? null)
.map((message) =>
message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
)
.flat(1);
const [res, controller] = await generateChatCompletion(localStorage.token, {
model: model,
messages: messagesBody,
@@ -375,7 +335,8 @@
...($settings.options ?? {})
},
format: $settings.requestFormat ?? undefined,
keep_alive: $settings.keepAlive ?? undefined
keep_alive: $settings.keepAlive ?? undefined,
docs: docs
});
if (res && res.ok) {
@@ -535,6 +496,13 @@
const responseMessage = history.messages[responseMessageId];
scrollToBottom();
const docs = messages
.filter((message) => message?.files ?? null)
.map((message) =>
message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
)
.flat(1);
const res = await generateOpenAIChatCompletion(
localStorage.token,
{
@@ -583,7 +551,8 @@
top_p: $settings?.options?.top_p ?? undefined,
num_ctx: $settings?.options?.num_ctx ?? undefined,
frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
max_tokens: $settings?.options?.num_predict ?? undefined
max_tokens: $settings?.options?.num_predict ?? undefined,
docs: docs
},
model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
);