forked from open-webui/open-webui
refac: rag pipeline
This commit is contained in:
parent
8f1563a7a5
commit
ce9a5d12e0
3 changed files with 179 additions and 154 deletions
|
@@ -47,9 +47,11 @@ from apps.web.models.documents import (
|
|||
|
||||
from apps.rag.utils import (
|
||||
get_model_path,
|
||||
query_embeddings_doc,
|
||||
get_embeddings_function,
|
||||
query_embeddings_collection,
|
||||
get_embedding_function,
|
||||
query_doc,
|
||||
query_doc_with_hybrid_search,
|
||||
query_collection,
|
||||
query_collection_with_hybrid_search,
|
||||
)
|
||||
|
||||
from utils.misc import (
|
||||
|
@@ -147,6 +149,15 @@ update_reranking_model(
|
|||
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||
)
|
||||
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.OPENAI_API_KEY,
|
||||
app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
origins = ["*"]
|
||||
|
||||
|
||||
|
@@ -227,6 +238,14 @@ async def update_embedding_config(
|
|||
|
||||
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
||||
|
||||
app.state.EMBEDDING_FUNCTION = get_embedding_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.OPENAI_API_KEY,
|
||||
app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return {
|
||||
"status": True,
|
||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||
|
@@ -367,27 +386,22 @@ def query_doc_handler(
|
|||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
embeddings_function = get_embeddings_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.OPENAI_API_KEY,
|
||||
app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return query_embeddings_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
embeddings_function=embeddings_function,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
hybrid_search=(
|
||||
form_data.hybrid
|
||||
if form_data.hybrid
|
||||
else app.state.ENABLE_RAG_HYBRID_SEARCH
|
||||
),
|
||||
)
|
||||
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
||||
return query_doc_with_hybrid_search(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
)
|
||||
else:
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
|
@@ -410,27 +424,23 @@ def query_collection_handler(
|
|||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
embeddings_function = get_embeddings_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
app.state.OPENAI_API_KEY,
|
||||
app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
if app.state.ENABLE_RAG_HYBRID_SEARCH:
|
||||
return query_collection_with_hybrid_search(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
)
|
||||
else:
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
embeddings_function=app.state.EMBEDDING_FUNCTION,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
|
||||
return query_embeddings_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
|
||||
embeddings_function=embeddings_function,
|
||||
reranking_function=app.state.sentence_transformer_rf,
|
||||
hybrid_search=(
|
||||
form_data.hybrid
|
||||
if form_data.hybrid
|
||||
else app.state.ENABLE_RAG_HYBRID_SEARCH
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
|
@@ -508,7 +518,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
embedding_func = get_embeddings_function(
|
||||
embedding_func = get_embedding_function(
|
||||
app.state.RAG_EMBEDDING_ENGINE,
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
app.state.sentence_transformer_ef,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue