Merge pull request #779 from open-webui/editable-chunk-params

feat: editable chunk params
This commit is contained in:
Timothy Jaeryang Baek 2024-02-18 01:49:24 -05:00 committed by GitHub
commit b993b66cfb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 296 additions and 23 deletions

View file

@ -62,6 +62,7 @@ from config import (
CHROMA_CLIENT, CHROMA_CLIENT,
CHUNK_SIZE, CHUNK_SIZE,
CHUNK_OVERLAP, CHUNK_OVERLAP,
RAG_TEMPLATE,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
@ -71,6 +72,11 @@ from constants import ERROR_MESSAGES
app = FastAPI() app = FastAPI()
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
origins = ["*"] origins = ["*"]
app.add_middleware( app.add_middleware(
@ -92,7 +98,7 @@ class StoreWebForm(CollectionNameForm):
def store_data_in_vector_db(data, collection_name) -> bool: def store_data_in_vector_db(data, collection_name) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
) )
docs = text_splitter.split_documents(data) docs = text_splitter.split_documents(data)
@ -116,7 +122,60 @@ def store_data_in_vector_db(data, collection_name) -> bool:
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return {"status": True} return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
@app.get("/chunk")
async def get_chunk_params(user=Depends(get_admin_user)):
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
class ChunkParamUpdateForm(BaseModel):
chunk_size: int
chunk_overlap: int
@app.post("/chunk/update")
async def update_chunk_params(
form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
):
app.state.CHUNK_SIZE = form_data.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk_overlap
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
}
class RAGTemplateForm(BaseModel):
template: str
@app.post("/template/update")
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
# TODO: check template requirements
app.state.RAG_TEMPLATE = (
form_data.template if form_data.template != "" else RAG_TEMPLATE
)
return {"status": True, "template": app.state.RAG_TEMPLATE}
class QueryDocForm(BaseModel): class QueryDocForm(BaseModel):

View file

@ -144,6 +144,21 @@ CHROMA_CLIENT = chromadb.PersistentClient(
CHUNK_SIZE = 1500 CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100 CHUNK_OVERLAP = 100
RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
#################################### ####################################
# Transcribe # Transcribe
#################################### ####################################

View file

@ -1,5 +1,120 @@
import { RAG_API_BASE_URL } from '$lib/constants'; import { RAG_API_BASE_URL } from '$lib/constants';
export const getChunkParams = async (token: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/chunk`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateChunkParams = async (token: string, size: number, overlap: number) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
chunk_size: size,
chunk_overlap: overlap
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const getRAGTemplate = async (token: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/template`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res?.template ?? '';
};
export const updateRAGTemplate = async (token: string, template: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/template/update`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
template: template
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => { export const uploadDocToVectorDB = async (token: string, collection_name: string, file: File) => {
const data = new FormData(); const data = new FormData();
data.append('file', file); data.append('file', file);

View file

@ -1,6 +1,12 @@
<script lang="ts"> <script lang="ts">
import { getDocs } from '$lib/apis/documents'; import { getDocs } from '$lib/apis/documents';
import { scanDocs } from '$lib/apis/rag'; import {
getChunkParams,
getRAGTemplate,
scanDocs,
updateChunkParams,
updateRAGTemplate
} from '$lib/apis/rag';
import { documents } from '$lib/stores'; import { documents } from '$lib/stores';
import { onMount } from 'svelte'; import { onMount } from 'svelte';
import toast from 'svelte-french-toast'; import toast from 'svelte-french-toast';
@ -9,6 +15,11 @@
let loading = false; let loading = false;
let chunkSize = 0;
let chunkOverlap = 0;
let template = '';
const scanHandler = async () => { const scanHandler = async () => {
loading = true; loading = true;
const res = await scanDocs(localStorage.token); const res = await scanDocs(localStorage.token);
@ -20,13 +31,27 @@
} }
}; };
onMount(async () => {}); const submitHandler = async () => {
const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
await updateRAGTemplate(localStorage.token, template);
};
onMount(async () => {
const res = await getChunkParams(localStorage.token);
if (res) {
chunkSize = res.chunk_size;
chunkOverlap = res.chunk_overlap;
}
template = await getRAGTemplate(localStorage.token);
});
</script> </script>
<form <form
class="flex flex-col h-full justify-between space-y-3 text-sm" class="flex flex-col h-full justify-between space-y-3 text-sm"
on:submit|preventDefault={() => { on:submit|preventDefault={() => {
// console.log('submit'); submitHandler();
saveHandler(); saveHandler();
}} }}
> >
@ -93,14 +118,61 @@
</button> </button>
</div> </div>
</div> </div>
<hr class=" dark:border-gray-700" />
<div class=" ">
<div class=" text-sm font-medium">Chunk Params</div>
<div class=" flex">
<div class=" flex w-full justify-between">
<div class="self-center text-xs font-medium min-w-fit">Chunk Size</div>
<div class="self-center p-3">
<input
class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
type="number"
placeholder="Enter Chunk Size"
bind:value={chunkSize}
autocomplete="off"
min="0"
/>
</div>
</div> </div>
<!-- <div class="flex justify-end pt-3 text-sm font-medium"> <div class="flex w-full">
<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
<div class="self-center p-3">
<input
class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
type="number"
placeholder="Enter Chunk Overlap"
bind:value={chunkOverlap}
autocomplete="off"
min="0"
/>
</div>
</div>
</div>
<div>
<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
<textarea
bind:value={template}
class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
rows="4"
/>
</div>
</div>
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
<button <button
class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded" class=" px-4 py-2 bg-emerald-600 hover:bg-emerald-700 text-gray-100 transition rounded"
type="submit" type="submit"
> >
Save Save
</button> </button>
</div> --> </div>
</form> </form>

View file

@ -1,5 +1,8 @@
export const RAGTemplate = (context: string, query: string) => { import { getRAGTemplate } from '$lib/apis/rag';
let template = `Use the following context as your learned knowledge, inside <context></context> XML tags.
export const RAGTemplate = async (token: string, context: string, query: string) => {
let template = await getRAGTemplate(token).catch(() => {
return `Use the following context as your learned knowledge, inside <context></context> XML tags.
<context> <context>
[context] [context]
</context> </context>
@ -12,6 +15,7 @@ export const RAGTemplate = (context: string, query: string) => {
Given the context information, answer the query. Given the context information, answer the query.
Query: [query]`; Query: [query]`;
});
template = template.replace(/\[context\]/g, context); template = template.replace(/\[context\]/g, context);
template = template.replace(/\[query\]/g, query); template = template.replace(/\[query\]/g, query);

View file

@ -266,7 +266,11 @@
console.log(contextString); console.log(contextString);
history.messages[parentId].raContent = RAGTemplate(contextString, query); history.messages[parentId].raContent = await RAGTemplate(
localStorage.token,
contextString,
query
);
history.messages[parentId].contexts = relevantContexts; history.messages[parentId].contexts = relevantContexts;
await tick(); await tick();
processing = ''; processing = '';

View file

@ -280,7 +280,11 @@
console.log(contextString); console.log(contextString);
history.messages[parentId].raContent = RAGTemplate(contextString, query); history.messages[parentId].raContent = await RAGTemplate(
localStorage.token,
contextString,
query
);
history.messages[parentId].contexts = relevantContexts; history.messages[parentId].contexts = relevantContexts;
await tick(); await tick();
processing = ''; processing = '';