forked from open-webui/open-webui
feat: add rag top k value setting
This commit is contained in:
parent
9694c6569f
commit
47a05a47b4
5 changed files with 123 additions and 38 deletions
|
@ -79,6 +79,8 @@ app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||||
|
app.state.TOP_K = 4
|
||||||
|
|
||||||
app.state.sentence_transformer_ef = (
|
app.state.sentence_transformer_ef = (
|
||||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
@ -210,23 +212,33 @@ async def get_rag_template(user=Depends(get_current_user)):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class RAGTemplateForm(BaseModel):
|
@app.get("/query/settings")
|
||||||
template: str
|
async def get_query_settings(user=Depends(get_admin_user)):
|
||||||
|
return {
|
||||||
|
"status": True,
|
||||||
|
"template": app.state.RAG_TEMPLATE,
|
||||||
|
"k": app.state.TOP_K,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.post("/template/update")
|
class QuerySettingsForm(BaseModel):
|
||||||
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
|
k: Optional[int] = None
|
||||||
# TODO: check template requirements
|
template: Optional[str] = None
|
||||||
app.state.RAG_TEMPLATE = (
|
|
||||||
form_data.template if form_data.template != "" else RAG_TEMPLATE
|
|
||||||
)
|
@app.post("/query/settings/update")
|
||||||
|
async def update_query_settings(
|
||||||
|
form_data: QuerySettingsForm, user=Depends(get_admin_user)
|
||||||
|
):
|
||||||
|
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
|
||||||
|
app.state.TOP_K = form_data.k if form_data.k else 4
|
||||||
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
return {"status": True, "template": app.state.RAG_TEMPLATE}
|
||||||
|
|
||||||
|
|
||||||
class QueryDocForm(BaseModel):
|
class QueryDocForm(BaseModel):
|
||||||
collection_name: str
|
collection_name: str
|
||||||
query: str
|
query: str
|
||||||
k: Optional[int] = 4
|
k: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@app.post("/query/doc")
|
@app.post("/query/doc")
|
||||||
|
@ -240,7 +252,10 @@ def query_doc(
|
||||||
name=form_data.collection_name,
|
name=form_data.collection_name,
|
||||||
embedding_function=app.state.sentence_transformer_ef,
|
embedding_function=app.state.sentence_transformer_ef,
|
||||||
)
|
)
|
||||||
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
|
result = collection.query(
|
||||||
|
query_texts=[form_data.query],
|
||||||
|
n_results=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
@ -253,7 +268,7 @@ def query_doc(
|
||||||
class QueryCollectionsForm(BaseModel):
|
class QueryCollectionsForm(BaseModel):
|
||||||
collection_names: List[str]
|
collection_names: List[str]
|
||||||
query: str
|
query: str
|
||||||
k: Optional[int] = 4
|
k: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
def merge_and_sort_query_results(query_results, k):
|
def merge_and_sort_query_results(query_results, k):
|
||||||
|
@ -317,13 +332,16 @@ def query_collection(
|
||||||
)
|
)
|
||||||
|
|
||||||
result = collection.query(
|
result = collection.query(
|
||||||
query_texts=[form_data.query], n_results=form_data.k
|
query_texts=[form_data.query],
|
||||||
|
n_results=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
)
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return merge_and_sort_query_results(results, form_data.k)
|
return merge_and_sort_query_results(
|
||||||
|
results, form_data.k if form_data.k else app.state.TOP_K
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/web")
|
@app.post("/web")
|
||||||
|
@ -423,7 +441,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
|
||||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||||
] or file_ext in ["xls", "xlsx"]:
|
] or file_ext in ["xls", "xlsx"]:
|
||||||
loader = UnstructuredExcelLoader(file_path)
|
loader = UnstructuredExcelLoader(file_path)
|
||||||
elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0):
|
elif file_ext in known_source_ext or (
|
||||||
|
file_content_type and file_content_type.find("text/") >= 0
|
||||||
|
):
|
||||||
loader = TextLoader(file_path)
|
loader = TextLoader(file_path)
|
||||||
else:
|
else:
|
||||||
loader = TextLoader(file_path)
|
loader = TextLoader(file_path)
|
||||||
|
|
|
@ -85,17 +85,49 @@ export const getRAGTemplate = async (token: string) => {
|
||||||
return res?.template ?? '';
|
return res?.template ?? '';
|
||||||
};
|
};
|
||||||
|
|
||||||
export const updateRAGTemplate = async (token: string, template: string) => {
|
export const getQuerySettings = async (token: string) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
const res = await fetch(`${RAG_API_BASE_URL}/template/update`, {
|
const res = await fetch(`${RAG_API_BASE_URL}/query/settings`, {
|
||||||
|
method: 'GET',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${token}`
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.then(async (res) => {
|
||||||
|
if (!res.ok) throw await res.json();
|
||||||
|
return res.json();
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.log(err);
|
||||||
|
error = err.detail;
|
||||||
|
return null;
|
||||||
|
});
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
};
|
||||||
|
|
||||||
|
type QuerySettings = {
|
||||||
|
k: number | null;
|
||||||
|
template: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const updateQuerySettings = async (token: string, settings: QuerySettings) => {
|
||||||
|
let error = null;
|
||||||
|
|
||||||
|
const res = await fetch(`${RAG_API_BASE_URL}/query/settings/update`, {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
Authorization: `Bearer ${token}`
|
Authorization: `Bearer ${token}`
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
template: template
|
...settings
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.then(async (res) => {
|
.then(async (res) => {
|
||||||
|
@ -183,7 +215,7 @@ export const queryDoc = async (
|
||||||
token: string,
|
token: string,
|
||||||
collection_name: string,
|
collection_name: string,
|
||||||
query: string,
|
query: string,
|
||||||
k: number
|
k: number | null = null
|
||||||
) => {
|
) => {
|
||||||
let error = null;
|
let error = null;
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,10 @@
|
||||||
import { getDocs } from '$lib/apis/documents';
|
import { getDocs } from '$lib/apis/documents';
|
||||||
import {
|
import {
|
||||||
getChunkParams,
|
getChunkParams,
|
||||||
getRAGTemplate,
|
getQuerySettings,
|
||||||
scanDocs,
|
scanDocs,
|
||||||
updateChunkParams,
|
updateChunkParams,
|
||||||
updateRAGTemplate
|
updateQuerySettings
|
||||||
} from '$lib/apis/rag';
|
} from '$lib/apis/rag';
|
||||||
import { documents } from '$lib/stores';
|
import { documents } from '$lib/stores';
|
||||||
import { onMount } from 'svelte';
|
import { onMount } from 'svelte';
|
||||||
|
@ -18,7 +18,10 @@
|
||||||
let chunkSize = 0;
|
let chunkSize = 0;
|
||||||
let chunkOverlap = 0;
|
let chunkOverlap = 0;
|
||||||
|
|
||||||
let template = '';
|
let querySettings = {
|
||||||
|
template: '',
|
||||||
|
k: 4
|
||||||
|
};
|
||||||
|
|
||||||
const scanHandler = async () => {
|
const scanHandler = async () => {
|
||||||
loading = true;
|
loading = true;
|
||||||
|
@ -33,7 +36,7 @@
|
||||||
|
|
||||||
const submitHandler = async () => {
|
const submitHandler = async () => {
|
||||||
const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
|
const res = await updateChunkParams(localStorage.token, chunkSize, chunkOverlap);
|
||||||
await updateRAGTemplate(localStorage.token, template);
|
querySettings = await updateQuerySettings(localStorage.token, querySettings);
|
||||||
};
|
};
|
||||||
|
|
||||||
onMount(async () => {
|
onMount(async () => {
|
||||||
|
@ -44,7 +47,7 @@
|
||||||
chunkOverlap = res.chunk_overlap;
|
chunkOverlap = res.chunk_overlap;
|
||||||
}
|
}
|
||||||
|
|
||||||
template = await getRAGTemplate(localStorage.token);
|
querySettings = await getQuerySettings(localStorage.token);
|
||||||
});
|
});
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
|
@ -156,10 +159,44 @@
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<div class=" text-sm font-medium">Query Params</div>
|
||||||
|
|
||||||
|
<div class=" flex">
|
||||||
|
<div class=" flex w-full justify-between">
|
||||||
|
<div class="self-center text-xs font-medium flex-1">Top K</div>
|
||||||
|
|
||||||
|
<div class="self-center p-3">
|
||||||
|
<input
|
||||||
|
class=" w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
|
||||||
|
type="number"
|
||||||
|
placeholder="Enter Top K"
|
||||||
|
bind:value={querySettings.k}
|
||||||
|
autocomplete="off"
|
||||||
|
min="0"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- <div class="flex w-full">
|
||||||
|
<div class=" self-center text-xs font-medium min-w-fit">Chunk Overlap</div>
|
||||||
|
|
||||||
|
<div class="self-center p-3">
|
||||||
|
<input
|
||||||
|
class="w-full rounded py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none border border-gray-100 dark:border-gray-600"
|
||||||
|
type="number"
|
||||||
|
placeholder="Enter Chunk Overlap"
|
||||||
|
bind:value={chunkOverlap}
|
||||||
|
autocomplete="off"
|
||||||
|
min="0"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div> -->
|
||||||
|
</div>
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
|
<div class=" mb-2.5 text-sm font-medium">RAG Template</div>
|
||||||
<textarea
|
<textarea
|
||||||
bind:value={template}
|
bind:value={querySettings.template}
|
||||||
class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
|
class="w-full rounded p-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none resize-none"
|
||||||
rows="4"
|
rows="4"
|
||||||
/>
|
/>
|
||||||
|
|
|
@ -248,19 +248,17 @@
|
||||||
let relevantContexts = await Promise.all(
|
let relevantContexts = await Promise.all(
|
||||||
docs.map(async (doc) => {
|
docs.map(async (doc) => {
|
||||||
if (doc.type === 'collection') {
|
if (doc.type === 'collection') {
|
||||||
return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
|
return await queryCollection(localStorage.token, doc.collection_names, query).catch(
|
||||||
(error) => {
|
(error) => {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
|
return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
|
||||||
(error) => {
|
|
||||||
console.log(error);
|
console.log(error);
|
||||||
return null;
|
return null;
|
||||||
}
|
});
|
||||||
);
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
|
@ -261,19 +261,17 @@
|
||||||
let relevantContexts = await Promise.all(
|
let relevantContexts = await Promise.all(
|
||||||
docs.map(async (doc) => {
|
docs.map(async (doc) => {
|
||||||
if (doc.type === 'collection') {
|
if (doc.type === 'collection') {
|
||||||
return await queryCollection(localStorage.token, doc.collection_names, query, 4).catch(
|
return await queryCollection(localStorage.token, doc.collection_names, query).catch(
|
||||||
(error) => {
|
(error) => {
|
||||||
console.log(error);
|
console.log(error);
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
return await queryDoc(localStorage.token, doc.collection_name, query, 4).catch(
|
return await queryDoc(localStorage.token, doc.collection_name, query).catch((error) => {
|
||||||
(error) => {
|
|
||||||
console.log(error);
|
console.log(error);
|
||||||
return null;
|
return null;
|
||||||
}
|
});
|
||||||
);
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
);
|
);
|
||||||
|
|
Loading…
Reference in a new issue