refac: rag to backend

Timothy J. Baek 2024-03-08 22:34:47 -08:00
parent 6ba62cf25d
commit c49491e516
4 changed files with 113 additions and 50 deletions

View file

@@ -1,3 +1,4 @@
import re
from typing import List
from config import CHROMA_CLIENT
@@ -87,3 +88,10 @@ def query_collection(
            pass
    return merge_and_sort_query_results(results, k)

def rag_template(template: str, context: str, query: str):
    template = re.sub(r"\[context\]", context, template)
    template = re.sub(r"\[query\]", query, template)
    return template
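
For orientation, a quick usage sketch of the new helper (the template text here is invented; the app's real default template lives in its config):

# Hypothetical template string, not the app's actual RAG_TEMPLATE.
template = "Use the following context to answer:\n[context]\n\nQuery: [query]"
prompt = rag_template(template, context="Chunk A. Chunk B.", query="What do the docs say?")
# prompt == "Use the following context to answer:\nChunk A. Chunk B.\n\nQuery: What do the docs say?"
# Note: re.sub treats backslashes in the replacement specially, so a context
# containing sequences like "\1" would be mis-expanded or raise an error.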

View file

@@ -12,6 +12,7 @@ from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from apps.ollama.main import app as ollama_app
@@ -23,6 +24,8 @@ from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
from apps.rag.utils import query_doc, query_collection, rag_template
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
from constants import ERROR_MESSAGES
@@ -56,6 +59,89 @@ async def on_startup():
    await litellm_app_startup()

class RAGMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        print(request.url.path)
        if request.method == "POST":
            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            # Example: Add a new key-value pair or modify existing ones
            # data["modified"] = True  # Example modification
            if "docs" in data:
                docs = data["docs"]
                print(docs)

                last_user_message_idx = None
                for i in range(len(data["messages"]) - 1, -1, -1):
                    if data["messages"][i]["role"] == "user":
                        last_user_message_idx = i
                        break

                query = data["messages"][last_user_message_idx]["content"]

                relevant_contexts = []
                for doc in docs:
                    context = None
                    if doc["type"] == "collection":
                        context = query_collection(
                            collection_names=doc["collection_names"],
                            query=query,
                            k=rag_app.state.TOP_K,
                            embedding_function=rag_app.state.sentence_transformer_ef,
                        )
                    else:
                        context = query_doc(
                            collection_name=doc["collection_name"],
                            query=query,
                            k=rag_app.state.TOP_K,
                            embedding_function=rag_app.state.sentence_transformer_ef,
                        )
                    relevant_contexts.append(context)

                context_string = ""
                for context in relevant_contexts:
                    if context:
                        context_string += " ".join(context["documents"][0]) + "\n"

                content = rag_template(
                    template=rag_app.state.RAG_TEMPLATE,
                    context=context_string,
                    query=query,
                )

                new_user_message = {
                    **data["messages"][last_user_message_idx],
                    "content": content,
                }
                data["messages"][last_user_message_idx] = new_user_message
                del data["docs"]

                print("DATAAAAAAAAAAAAAAAAAA")
                print(data)

            modified_body_bytes = json.dumps(data).encode("utf-8")

            # Create a new request with the modified body
            scope = request.scope
            scope["body"] = modified_body_bytes
            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(RAGMiddleware)

@app.middleware("http")
async def check_url(request: Request, call_next):
    start_time = int(time.time())
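
To make the middleware's effect concrete, here is a hedged before/after sketch of a chat POST body (the model name, collection name, and message text are invented; the exact rewritten content depends on the configured RAG_TEMPLATE):

# Hypothetical body as sent by the frontend (see the frontend diff below):
before = {
    "model": "llama2",
    "messages": [{"role": "user", "content": "What is our refund policy?"}],
    "docs": [{"type": "collection", "collection_names": ["policies"]}],
}
# After RAGMiddleware.dispatch(), "docs" is removed and the last user message
# has been expanded via rag_template(RAG_TEMPLATE, context, query):
after = {
    "model": "llama2",
    "messages": [
        {
            "role": "user",
            "content": "Use the following context ...\n<retrieved chunks>\n\nQuery: What is our refund policy?",
        }
    ],
}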

View file

@@ -252,7 +252,7 @@ export const queryCollection = async (
	token: string,
	collection_names: string,
	query: string,
	k: number | null = null
) => {
	let error = null;
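
Making k nullable here suggests the backend endpoint now applies its own default; presumably something like this server-side fallback (a sketch only; form_data is an assumed name, not the actual endpoint code):

# Hypothetical fallback in the rag app's query endpoint:
k = form_data.k if form_data.k is not None else app.state.TOP_K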

View file

@@ -232,53 +232,6 @@
	const sendPrompt = async (prompt, parentId) => {
		const _chatId = JSON.parse(JSON.stringify($chatId));

		const docs = messages
			.filter((message) => message?.files ?? null)
			.map((message) =>
				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
			)
			.flat(1);

		console.log(docs);
		if (docs.length > 0) {
			processing = 'Reading';
			const query = history.messages[parentId].content;

			let relevantContexts = await Promise.all(
				docs.map(async (doc) => {
					if (doc.type === 'collection') {
						return await queryCollection(localStorage.token, doc.collection_names, query).catch(
							(error) => {
								console.log(error);
								return null;
							}
						);
					} else {
						return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
							console.log(error);
							return null;
						});
					}
				})
			);

			relevantContexts = relevantContexts.filter((context) => context);

			const contextString = relevantContexts.reduce((a, context, i, arr) => {
				return `${a}${context.documents.join(' ')}\n`;
			}, '');

			console.log(contextString);

			history.messages[parentId].raContent = await RAGTemplate(
				localStorage.token,
				contextString,
				query
			);
			history.messages[parentId].contexts = relevantContexts;
			await tick();
			processing = '';
		}

		await Promise.all(
			selectedModels.map(async (modelId) => {
				const model = $models.filter((m) => m.id === modelId).at(0);
@@ -368,6 +321,13 @@
			}
		});

		const docs = messages
			.filter((message) => message?.files ?? null)
			.map((message) =>
				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
			)
			.flat(1);

		const [res, controller] = await generateChatCompletion(localStorage.token, {
			model: model,
			messages: messagesBody,
@@ -375,7 +335,8 @@
				...($settings.options ?? {})
			},
			format: $settings.requestFormat ?? undefined,
			keep_alive: $settings.keepAlive ?? undefined,
			docs: docs
		});

		if (res && res.ok) {
@@ -535,6 +496,13 @@
		const responseMessage = history.messages[responseMessageId];
		scrollToBottom();

		const docs = messages
			.filter((message) => message?.files ?? null)
			.map((message) =>
				message.files.filter((item) => item.type === 'doc' || item.type === 'collection')
			)
			.flat(1);

		const res = await generateOpenAIChatCompletion(
			localStorage.token,
			{
@@ -583,7 +551,8 @@
				top_p: $settings?.options?.top_p ?? undefined,
				num_ctx: $settings?.options?.num_ctx ?? undefined,
				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
				max_tokens: $settings?.options?.num_predict ?? undefined,
				docs: docs
			},
			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
		);