feat: editable rag template

This commit is contained in:
Timothy J. Baek 2024-02-17 22:41:03 -08:00
parent ccf08fb91e
commit 5270efa9e5
6 changed files with 122 additions and 16 deletions

View file

@ -62,6 +62,7 @@ from config import (
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
)
from constants import ERROR_MESSAGES
@ -73,6 +74,8 @@ app = FastAPI()
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
origins = ["*"]
@ -154,6 +157,25 @@ async def update_chunk_params(
}
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
}
class RAGTemplateForm(BaseModel):
template: str
@app.post("/template/update")
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
# TODO: check template requirements
app.state.RAG_TEMPLATE = form_data.template
return {"status": True, "template": app.state.RAG_TEMPLATE}
class QueryDocForm(BaseModel):
collection_name: str
query: str

View file

@ -144,6 +144,21 @@ CHROMA_CLIENT = chromadb.PersistentClient(
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
####################################
# Transcribe
####################################

View file

@ -58,6 +58,63 @@ export const updateChunkParams = async (token: string, size: number, overlap: nu
return res;
};
export const getRAGTemplate = async (token: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/template`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateRAGTemplate = async (token: string, template: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/template/update`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
template: template
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
const data = new FormData();
data.append('file', file);

View file

@ -1,5 +1,8 @@
export const RAGTemplate = (context: string, query: string) => {
let template = `Use the following context as your learned knowledge, inside <context></context> XML tags.
import { getRAGTemplate } from '$lib/apis/rag';
export const RAGTemplate = async (token: string, context: string, query: string) => {
let template = await getRAGTemplate(token).catch(() => {
return `Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
@ -12,6 +15,7 @@ export const RAGTemplate = (context: string, query: string) => {
Given the context information, answer the query.
Query: [query]`;
});
template = template.replace(/\[context\]/g, context);
template = template.replace(/\[query\]/g, query);

View file

@ -266,7 +266,11 @@
console.log(contextString);
history.messages[parentId].raContent = RAGTemplate(contextString, query);
history.messages[parentId].raContent = await RAGTemplate(
localStorage.token,
contextString,
query
);
history.messages[parentId].contexts = relevantContexts;
await tick();
processing = '';

View file

@ -280,7 +280,11 @@
console.log(contextString);
history.messages[parentId].raContent = RAGTemplate(contextString, query);
history.messages[parentId].raContent = await RAGTemplate(
localStorage.token,
contextString,
query
);
history.messages[parentId].contexts = relevantContexts;
await tick();
processing = '';