feat: add rag top k value setting

This commit is contained in:
Timothy J. Baek 2024-03-02 18:56:57 -08:00
parent 9694c6569f
commit 47a05a47b4
5 changed files with 123 additions and 38 deletions

View file

@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.TOP_K = 4
app.state.sentence_transformer_ef = ( app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL, model_name=app.state.RAG_EMBEDDING_MODEL,
@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)):
} }
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
    # Report the RAG query configuration currently held on app.state.
    # The request must satisfy the get_admin_user dependency.
    settings = {"status": True}
    settings["template"] = app.state.RAG_TEMPLATE
    settings["k"] = app.state.TOP_K
    return settings
class QuerySettingsForm(BaseModel):
    # Request body for POST /query/settings/update.
    # Both fields may be omitted; each one defaults to None.
    k: Optional[int] = None
    template: Optional[str] = None
@app.post("/query/settings/update")
async def update_query_settings(
    form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
    # Store the RAG prompt template and the default top-k on app.state.
    # An empty or missing template falls back to the module-level RAG_TEMPLATE.
    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE

    # A missing, zero or negative k falls back to 4, the value app.state.TOP_K
    # is initialised with; a non-positive result count is never stored.
    app.state.TOP_K = form_data.k if form_data.k and form_data.k > 0 else 4

    # Echo back every stored setting, including k, so a client that replaces its
    # local state with this response keeps the top-k value.
    return {
        "status": True,
        "template": app.state.RAG_TEMPLATE,
        "k": app.state.TOP_K,
    }
class QueryDocForm(BaseModel):
    # Request body for POST /query/doc.
    collection_name: str
    query: str
    # Optional per-request result count; defaults to None when the client omits it.
    k: Optional[int] = None
@app.post("/query/doc") @app.post("/query/doc")
@ -240,7 +252,10 @@ def query_doc(
name=form_data.collection_name, name=form_data.collection_name,
embedding_function=app.state.sentence_transformer_ef, embedding_function=app.state.sentence_transformer_ef,
) )
result = collection.query(query_texts=[form_data.query], n_results=form_data.k) result = collection.query(
query_texts=[form_data.query],
n_results=form_data.k if form_data.k else app.state.TOP_K,
)
return result return result
except Exception as e: except Exception as e:
print(e) print(e)
@ -253,7 +268,7 @@ def query_doc(
class QueryCollectionsForm(BaseModel):
    # Request body for querying several collections with one query string.
    collection_names: List[str]
    query: str
    # Optional per-request result count; defaults to None when the client omits it.
    k: Optional[int] = None
def merge_and_sort_query_results(query_results, k): def merge_and_sort_query_results(query_results, k):
@ -317,13 +332,16 @@ def query_collection(
) )
result = collection.query( result = collection.query(
query_texts=[form_data.query], n_results=form_data.k query_texts=[form_data.query],
n_results=form_data.k if form_data.k else app.state.TOP_K,
) )
results.append(result) results.append(result)
except: except:
pass pass
return merge_and_sort_query_results(results, form_data.k) return merge_and_sort_query_results(
results, form_data.k if form_data.k else app.state.TOP_K
)
@app.post("/web") @app.post("/web")
@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]: ] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path) loader = UnstructuredExcelLoader(file_path)
elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0): elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path) loader = TextLoader(file_path)
else: else:
loader = TextLoader(file_path) loader = TextLoader(file_path)

View file

@ -85,17 +85,49 @@ export const getRAGTemplate = async (token: string) => {
return res?.template ?? ''; return res?.template ?? '';
}; };
/**
 * Fetches the current RAG query settings from `GET /query/settings`.
 *
 * @param token - Bearer token sent in the `Authorization` header.
 * @returns The parsed JSON response body.
 * @throws The `detail` field of the error body when the response is not OK, or
 *   the caught error itself when it carries no `detail` (for example a failed
 *   request).
 */
export const getQuerySettings = async (token: string) => {
	let error = null;

	const res = await fetch(`${RAG_API_BASE_URL}/query/settings`, {
		method: 'GET',
		headers: {
			'Content-Type': 'application/json',
			Authorization: `Bearer ${token}`
		}
	})
		.then(async (res) => {
			if (!res.ok) throw await res.json();
			return res.json();
		})
		.catch((err) => {
			console.log(err);
			// Errors without a `detail` field used to leave `error` undefined, which
			// made this function return null instead of throwing. Fall back to the
			// caught value so those failures are rethrown as well.
			error = err?.detail ?? err;
			return null;
		});

	if (error) {
		throw error;
	}

	return res;
};
/** Payload accepted by `updateQuerySettings`; either field may be `null`. */
interface QuerySettings {
	/** Number of results to retrieve per query (top-k). */
	k: number | null;
	/** RAG prompt template text. */
	template: string | null;
}
export const updateQuerySettings = async (token: string, settings: QuerySettings) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/query/settings/update`, {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json', 'Content-Type': 'application/json',
Authorization: `Bearer ${token}` Authorization: `Bearer ${token}`
}, },
body: JSON.stringify({ body: JSON.stringify({
template: template ...settings
}) })
}) })
.then(async (res) => { .then(async (res) => {
@ -183,7 +215,7 @@ export const queryDoc = async (
token: string, token: string,
collection_name: string, collection_name: string,
query: string, query: string,
k: number k: number | null = null
) => { ) => {
let error = null; let error = null;

View file

@ -2,10 +2,10 @@
import { getDocs } from '$lib/apis/documents'; import { getDocs } from '$lib/apis/documents';
import { import {
getChunkParams, getChunkParams,
getRAGTemplate, getQuerySettings,
scanDocs, scanDocs,
updateChunkParams, updateChunkParams,
updateRAGTemplate updateQuerySettings
} from '$lib/apis/rag'; } from '$lib/apis/rag';
import { documents } from '$lib/stores'; import { documents } from '$lib/stores';
import { onMount } from 'svelte'; import { onMount } from 'svelte';
@ -18,7 +18,10 @@
let chunkSize = 0; let chunkSize = 0;
let chunkOverlap = 0; let chunkOverlap = 0;
let template = ''; let querySettings = {
template: '',
k: 4
};
const scanHandler = async () => { const scanHandler = async () => {
loading = true; loading = true;
@ -33,7 +36,7 @@
// Saves the chunk parameters, then the query settings (RAG template and Top K).
const submitHandler = async () => {
	await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);

	const saved = await updateQuerySettings(localStorage.token, querySettings);

	// Copy back only the fields this form edits, keeping the current value for
	// anything the response omits, so the bound inputs never become undefined
	// and unrelated response keys are not stored in local state.
	querySettings = {
		template: saved?.template ?? querySettings.template,
		k: saved?.k ?? querySettings.k
	};
};
onMount(async () => { onMount(async () => {
@ -44,7 +47,7 @@
chunkOverlap = res.chunk_overlap; chunkOverlap = res.chunk_overlap;
} }
template = await getRAGTemplate(localStorage.token); querySettings = await getQuerySettings(localStorage.token);
}); });
</script> </script>
@ -156,10 +159,44 @@
</div> </div>
</div> </div>
<div class=" text-sm font-medium">Query Params</div>

<div class=" flex">
	<div class="  flex w-full justify-between">
		<div class="self-center text-xs font-medium flex-1">Top K</div>

		<div class="self-center p-3">
			<!-- The backend treats a Top K of 0 as unset and falls back to its default, so the smallest accepted value is 1. -->
			<input
				class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
				type="number"
				placeholder="Enter Top K"
				bind:value={querySettings.k}
				autocomplete="off"
				min="1"
			/>
		</div>
	</div>

	<!-- <div class="flex w-full">
		<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>

		<div class="self-center p-3">
			<input
				class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
				type="number"
				placeholder="Enter Chunk Overlap"
				bind:value={chunkOverlap}
				autocomplete="off"
				min="0"
			/>
		</div>
	</div> -->
</div>
<div> <div>
<div class=" mb-2.5 text-sm font-medium">RAG Template</div> <div class=" mb-2.5 text-sm font-medium">RAG Template</div>
<textarea <textarea
bind:value={template} bind:value={querySettings.template}
class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none" class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
rows="4" rows="4"
/> />

View file

@ -248,19 +248,17 @@
let relevantContexts = await Promise.all( let relevantContexts = await Promise.all(
docs.map(async (doc) => { docs.map(async (doc) => {
if (doc.type === 'collection') { if (doc.type === 'collection') {
return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( return await queryCollection(localStorage.token, doc.collection_names, query).catch(
(error) => { (error) => {
console.log(error); console.log(error);
return null; return null;
} }
); );
} else { } else {
return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
(error) => { console.log(error);
console.log(error); return null;
return null; });
}
);
} }
}) })
); );

View file

@ -261,19 +261,17 @@
let relevantContexts = await Promise.all( let relevantContexts = await Promise.all(
docs.map(async (doc) => { docs.map(async (doc) => {
if (doc.type === 'collection') { if (doc.type === 'collection') {
return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch( return await queryCollection(localStorage.token, doc.collection_names, query).catch(
(error) => { (error) => {
console.log(error); console.log(error);
return null; return null;
} }
); );
} else { } else {
return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch( return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
(error) => { console.log(error);
console.log(error); return null;
return null; });
}
);
} }
}) })
); );