diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py
index f24b6b90..6d06456f 100644
--- a/backend/apps/rag/main.py
+++ b/backend/apps/rag/main.py
@@ -62,6 +62,7 @@ from config import (
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
+ RAG_TEMPLATE,
)
from constants import ERROR_MESSAGES
@@ -71,6 +72,11 @@ from constants import ERROR_MESSAGES
app = FastAPI()
+app.state.CHUNK_SIZE = CHUNK_SIZE
+app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
+app.state.RAG_TEMPLATE = RAG_TEMPLATE
+
+
origins = ["*"]
app.add_middleware(
@@ -92,7 +98,7 @@ class StoreWebForm(CollectionNameForm):
def store_data_in_vector_db(data, collection_name) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
- chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
+ chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
@@ -116,7 +122,60 @@ def store_data_in_vector_db(data, collection_name) -> bool:
@app.get("/")
async def get_status():
- return {"status": True}
+ return {
+ "status": True,
+ "chunk_size": app.state.CHUNK_SIZE,
+ "chunk_overlap": app.state.CHUNK_OVERLAP,
+ }
+
+
+@app.get("/chunk")
+async def get_chunk_params(user=Depends(get_admin_user)):
+ return {
+ "status": True,
+ "chunk_size": app.state.CHUNK_SIZE,
+ "chunk_overlap": app.state.CHUNK_OVERLAP,
+ }
+
+
+class ChunkParamUpdateForm(BaseModel):
+ chunk_size: int
+ chunk_overlap: int
+
+
+@app.post("/chunk/update")
+async def update_chunk_params(
+ form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
+):
+ app.state.CHUNK_SIZE = form_data.chunk_size
+ app.state.CHUNK_OVERLAP = form_data.chunk_overlap
+
+ return {
+ "status": True,
+ "chunk_size": app.state.CHUNK_SIZE,
+ "chunk_overlap": app.state.CHUNK_OVERLAP,
+ }
+
+
+@app.get("/template")
+async def get_rag_template(user=Depends(get_current_user)):
+ return {
+ "status": True,
+ "template": app.state.RAG_TEMPLATE,
+ }
+
+
+class RAGTemplateForm(BaseModel):
+ template: str
+
+
+@app.post("/template/update")
+async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
+ # TODO: check template requirements
+ app.state.RAG_TEMPLATE = (
+ form_data.template if form_data.template != "" else RAG_TEMPLATE
+ )
+ return {"status": True, "template": app.state.RAG_TEMPLATE}
class QueryDocForm(BaseModel):
diff --git a/backend/config.py b/backend/config.py
index f5acf06b..440256c4 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -144,6 +144,21 @@ CHROMA_CLIENT = chromadb.PersistentClient(
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
+
+RAG_TEMPLATE = """Use the following context as your learned knowledge, inside XML tags.
+
+ [context]
+
+
+When answer to user:
+- If you don't know, just say that you don't know.
+- If you don't know when you are not sure, ask for clarification.
+Avoid mentioning that you obtained the information from the context.
+And answer according to the language of the user's question.
+
+Given the context information, answer the query.
+Query: [query]"""
+
####################################
# Transcribe
####################################
diff --git a/src/lib/apis/rag/index.ts b/src/lib/apis/rag/index.ts
index fc3571aa..ed36f014 100644
--- a/src/lib/apis/rag/index.ts
+++ b/src/lib/apis/rag/index.ts
@@ -1,5 +1,120 @@
import { RAG_API_BASE_URL } from '$lib/constants';
+export const getChunkParams = async (token: string) => {
+ let error = null;
+
+ const res = await fetch(`${RAG_API_BASE_URL}/chunk`, {
+ method: 'GET',
+ headers: {
+ 'Content-Type': 'application/json',
+ Authorization: `Bearer ${token}`
+ }
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.log(err);
+ error = err.detail;
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
+export const updateChunkParams = async (token: string, size: number, overlap: number) => {
+ let error = null;
+
+ const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ Authorization: `Bearer ${token}`
+ },
+ body: JSON.stringify({
+ chunk_size: size,
+ chunk_overlap: overlap
+ })
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.log(err);
+ error = err.detail;
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
+export const getRAGTemplate = async (token: string) => {
+ let error = null;
+
+ const res = await fetch(`${RAG_API_BASE_URL}/template`, {
+ method: 'GET',
+ headers: {
+ 'Content-Type': 'application/json',
+ Authorization: `Bearer ${token}`
+ }
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.log(err);
+ error = err.detail;
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res?.template ?? '';
+};
+
+export const updateRAGTemplate = async (token: string, template: string) => {
+ let error = null;
+
+ const res = await fetch(`${RAG_API_BASE_URL}/template/update`, {
+ method: 'POST',
+ headers: {
+ 'Content-Type': 'application/json',
+ Authorization: `Bearer ${token}`
+ },
+ body: JSON.stringify({
+ template: template
+ })
+ })
+ .then(async (res) => {
+ if (!res.ok) throw await res.json();
+ return res.json();
+ })
+ .catch((err) => {
+ console.log(err);
+ error = err.detail;
+ return null;
+ });
+
+ if (error) {
+ throw error;
+ }
+
+ return res;
+};
+
export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
const data = new FormData();
data.append('file', file);
diff --git a/src/lib/components/documents/Settings/General.svelte b/src/lib/components/documents/Settings/General.svelte
index c3c7df5b..503cbc84 100644
--- a/src/lib/components/documents/Settings/General.svelte
+++ b/src/lib/components/documents/Settings/General.svelte
@@ -1,6 +1,12 @@
diff --git a/src/lib/utils/rag/index.ts b/src/lib/utils/rag/index.ts
index 6b219ef2..ba1f29f8 100644
--- a/src/lib/utils/rag/index.ts
+++ b/src/lib/utils/rag/index.ts
@@ -1,17 +1,21 @@
-export const RAGTemplate = (context: string, query: string) => {
- let template = `Use the following context as your learned knowledge, inside XML tags.
-
- [context]
-
-
- When answer to user:
- - If you don't know, just say that you don't know.
- - If you don't know when you are not sure, ask for clarification.
- Avoid mentioning that you obtained the information from the context.
- And answer according to the language of the user's question.
-
- Given the context information, answer the query.
- Query: [query]`;
+import { getRAGTemplate } from '$lib/apis/rag';
+
+export const RAGTemplate = async (token: string, context: string, query: string) => {
+ let template = await getRAGTemplate(token).catch(() => {
+ return `Use the following context as your learned knowledge, inside XML tags.
+
+ [context]
+
+
+ When answer to user:
+ - If you don't know, just say that you don't know.
+ - If you don't know when you are not sure, ask for clarification.
+ Avoid mentioning that you obtained the information from the context.
+ And answer according to the language of the user's question.
+
+ Given the context information, answer the query.
+ Query: [query]`;
+ });
template = template.replace(/\[context\]/g, context);
template = template.replace(/\[query\]/g, query);
diff --git a/src/routes/(app)/+page.svelte b/src/routes/(app)/+page.svelte
index 604cb544..1d91a614 100644
--- a/src/routes/(app)/+page.svelte
+++ b/src/routes/(app)/+page.svelte
@@ -266,7 +266,11 @@
console.log(contextString);
- history.messages[parentId].raContent = RAGTemplate(contextString, query);
+ history.messages[parentId].raContent = await RAGTemplate(
+ localStorage.token,
+ contextString,
+ query
+ );
history.messages[parentId].contexts = relevantContexts;
await tick();
processing = '';
diff --git a/src/routes/(app)/c/[id]/+page.svelte b/src/routes/(app)/c/[id]/+page.svelte
index aab03d74..b719ebf2 100644
--- a/src/routes/(app)/c/[id]/+page.svelte
+++ b/src/routes/(app)/c/[id]/+page.svelte
@@ -280,7 +280,11 @@
console.log(contextString);
- history.messages[parentId].raContent = RAGTemplate(contextString, query);
+ history.messages[parentId].raContent = await RAGTemplate(
+ localStorage.token,
+ contextString,
+ query
+ );
history.messages[parentId].contexts = relevantContexts;
await tick();
processing = '';