Merge pull request #1693 from buroa/buroa/hybrid-search

feat: hybrid search with reranking
Timothy Jaeryang Baek 2024-04-25 13:12:18 -07:00 committed by GitHub
commit 5ee2f1729a
8 changed files with 655 additions and 176 deletions


@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.1.122] - 2024-04-24
+
+- **🌟 Enhanced RAG Pipeline**: Added hybrid searching with `BM25`, reranking using `CrossEncoder`, and relevance score thresholds.
+
## [0.1.121] - 2024-04-24

### Fixed
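
For context on what the entry above means in practice: hybrid search runs a lexical retriever (BM25) and a dense vector retriever side by side, merges their candidates, and a cross-encoder can then rescore the merged list so that only results above a relevance threshold survive. A minimal, self-contained sketch of that idea, separate from the WebUI code (the corpus, query, and model names are placeholders):

```python
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import CrossEncoder, SentenceTransformer, util

corpus = [
    "Hybrid search merges BM25 and vector retrieval results.",
    "A cross-encoder scores a query and a passage jointly.",
    "Dockerfiles describe how container images are built.",
]
query = "how does reranking work?"

# Lexical side: BM25 over whitespace-tokenized passages.
bm25_scores = BM25Okapi([d.lower().split() for d in corpus]).get_scores(query.lower().split())

# Dense side: cosine similarity between sentence embeddings.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
dense_scores = util.cos_sim(embedder.encode(query), embedder.encode(corpus))[0].numpy()


def top_n(scores, n=2):
    return set(np.argsort(scores)[::-1][:n])


# Union of the best candidates from both retrievers, rescored by a cross-encoder.
candidates = sorted(top_n(bm25_scores) | top_n(dense_scores))
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
scores = reranker.predict([(query, corpus[i]) for i in candidates])

relevance_threshold = 0.0
ranked = sorted(zip(scores.tolist(), candidates), reverse=True)
print([corpus[i] for score, i in ranked if score >= relevance_threshold])
```
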


@@ -8,8 +8,9 @@ ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
+# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
+ARG USE_RERANKING_MODEL=""

######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build

@@ -30,6 +31,7 @@ ARG USE_CUDA
ARG USE_OLLAMA
ARG USE_CUDA_VER
ARG USE_EMBEDDING_MODEL
+ARG USE_RERANKING_MODEL

## Basis ##
ENV ENV=prod \

@@ -38,7 +40,8 @@ ENV ENV=prod \
    USE_OLLAMA_DOCKER=${USE_OLLAMA} \
    USE_CUDA_DOCKER=${USE_CUDA} \
    USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
-    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL}
+    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL} \
+    USE_RERANKING_MODEL_DOCKER=${USE_RERANKING_MODEL}

## Basis URL Config ##
ENV OLLAMA_BASE_URL="/ollama" \

@@ -62,8 +65,11 @@ ENV WHISPER_MODEL="base" \
## RAG Embedding model settings ##
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
-    RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" \
+    RAG_RERANKING_MODEL="$USE_RERANKING_MODEL_DOCKER" \
    SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models"
+## Hugging Face download cache ##
+ENV HF_HOME="/app/backend/data/cache/embedding/models"

#### Other models ##########################################################

WORKDIR /app/backend


@@ -39,8 +39,6 @@ import json
import sentence_transformers

-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
-
from apps.web.models.documents import (
    Documents,
    DocumentForm,

@@ -48,9 +46,10 @@ from apps.web.models.documents import (
)

from apps.rag.utils import (
+    get_model_path,
    query_embeddings_doc,
+    query_embeddings_function,
    query_embeddings_collection,
-    generate_openai_embeddings,
)

from utils.misc import (

@@ -60,13 +59,20 @@ from utils.misc import (
    extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user

from config import (
    SRC_LOG_LEVELS,
    UPLOAD_DIR,
    DOCS_DIR,
+    RAG_TOP_K,
+    RAG_RELEVANCE_THRESHOLD,
    RAG_EMBEDDING_ENGINE,
    RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+    RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+    RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
    RAG_OPENAI_API_BASE_URL,
    RAG_OPENAI_API_KEY,
    DEVICE_TYPE,

@@ -83,14 +89,14 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()

-app.state.TOP_K = 4
+app.state.TOP_K = RAG_TOP_K
+app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL

@@ -98,16 +104,48 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = False

-if app.state.RAG_EMBEDDING_ENGINE == "":
-    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-        app.state.RAG_EMBEDDING_MODEL,
-        device=DEVICE_TYPE,
-        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    )
+
+def update_embedding_model(
+    embedding_model: str,
+    update_model: bool = False,
+):
+    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
+        app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+            get_model_path(embedding_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_ef = None
+
+
+def update_reranking_model(
+    reranking_model: str,
+    update_model: bool = False,
+):
+    if reranking_model:
+        app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
+            get_model_path(reranking_model, update_model),
+            device=DEVICE_TYPE,
+            trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
+        )
+    else:
+        app.state.sentence_transformer_rf = None
+
+
+update_embedding_model(
+    app.state.RAG_EMBEDDING_MODEL,
+    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+)
+
+update_reranking_model(
+    app.state.RAG_RERANKING_MODEL,
+    RAG_RERANKING_MODEL_AUTO_UPDATE,
+)
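
The reranking side of the pipeline is a sentence-transformers `CrossEncoder`: unlike the bi-encoder embedding model, it reads the query and a candidate passage together and emits one relevance score per pair. A standalone sketch of what the `app.state.sentence_transformer_rf` object loaded above does (the model name is only an example):

```python
from sentence_transformers import CrossEncoder

# Any cross-encoder reranker from the Hugging Face Hub should work here.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", device="cpu")

pairs = [
    ("what is hybrid search?", "Hybrid search fuses BM25 and vector retrieval."),
    ("what is hybrid search?", "Dockerfiles describe container build steps."),
]
scores = reranker.predict(pairs)  # one score per (query, passage) pair
print(scores)  # the on-topic passage should receive the higher score
```
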
origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,

@@ -134,6 +172,7 @@ async def get_status():
        "template": app.state.RAG_TEMPLATE,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "reranking_model": app.state.RAG_RERANKING_MODEL,
    }

@@ -150,6 +189,11 @@ async def get_embedding_config(user=Depends(get_admin_user)):
    }

+
+@app.get("/reranking")
+async def get_reraanking_config(user=Depends(get_admin_user)):
+    return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
+
+
class OpenAIConfigForm(BaseModel):
    url: str
    key: str

@@ -170,22 +214,14 @@ async def update_embedding_config(
    )
    try:
        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
+        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-            app.state.sentence_transformer_ef = None
            if form_data.openai_config != None:
                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.OPENAI_API_KEY = form_data.openai_config.key
-        else:
-            sentence_transformer_ef = sentence_transformers.SentenceTransformer(
-                app.state.RAG_EMBEDDING_MODEL,
-                device=DEVICE_TYPE,
-                trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-            )
-            app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
-            app.state.sentence_transformer_ef = sentence_transformer_ef
+
+        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)

        return {
            "status": True,

@@ -196,7 +232,6 @@ async def update_embedding_config(
                "key": app.state.OPENAI_API_KEY,
            },
        }
    except Exception as e:
        log.exception(f"Problem updating embedding model: {e}")
        raise HTTPException(

@@ -205,6 +240,34 @@ async def update_embedding_config(
        )

+
+class RerankingModelUpdateForm(BaseModel):
+    reranking_model: str
+
+
+@app.post("/reranking/update")
+async def update_reranking_config(
+    form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
+):
+    log.info(
+        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
+    )
+    try:
+        app.state.RAG_RERANKING_MODEL = form_data.reranking_model
+
+        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
+
+        return {
+            "status": True,
+            "reranking_model": app.state.RAG_RERANKING_MODEL,
+        }
+    except Exception as e:
+        log.exception(f"Problem updating reranking model: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
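
The new admin endpoints can be exercised directly. A hedged example using `requests`, assuming the RAG router is mounted under `/rag/api/v1` (as the frontend's `RAG_API_BASE_URL` suggests) and that `TOKEN` holds an admin bearer token:

```python
import requests

BASE_URL = "http://localhost:8080/rag/api/v1"  # assumed mount point of this router
TOKEN = "<admin bearer token>"
headers = {"Authorization": f"Bearer {TOKEN}", "Content-Type": "application/json"}

# Read the currently configured reranking model.
print(requests.get(f"{BASE_URL}/reranking", headers=headers).json())

# Switch to a different cross-encoder (example model name) and trigger a download.
res = requests.post(
    f"{BASE_URL}/reranking/update",
    headers=headers,
    json={"reranking_model": "cross-encoder/ms-marco-MiniLM-L-6-v2"},
)
print(res.json())  # {"status": True, "reranking_model": "..."} on success
```
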
@app.get("/config") @app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)): async def get_rag_config(user=Depends(get_admin_user)):
return { return {
@ -257,11 +320,13 @@ async def get_query_settings(user=Depends(get_admin_user)):
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K, "k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
} }
class QuerySettingsForm(BaseModel): class QuerySettingsForm(BaseModel):
k: Optional[int] = None k: Optional[int] = None
r: Optional[float] = None
template: Optional[str] = None template: Optional[str] = None
@ -271,6 +336,7 @@ async def update_query_settings(
): ):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4 app.state.TOP_K = form_data.k if form_data.k else 4
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
return {"status": True, "template": app.state.RAG_TEMPLATE} return {"status": True, "template": app.state.RAG_TEMPLATE}
@ -278,6 +344,7 @@ class QueryDocForm(BaseModel):
collection_name: str collection_name: str
query: str query: str
k: Optional[int] = None k: Optional[int] = None
r: Optional[float] = None
@app.post("/query/doc") @app.post("/query/doc")
@ -286,34 +353,22 @@ def query_doc_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.RAG_EMBEDDING_ENGINE == "": embeddings_function = query_embeddings_function(
query_embeddings = app.state.sentence_transformer_ef.encode( app.state.RAG_EMBEDDING_ENGINE,
form_data.query app.state.RAG_EMBEDDING_MODEL,
).tolist() app.state.sentence_transformer_ef,
elif app.state.RAG_EMBEDDING_ENGINE == "ollama": app.state.OPENAI_API_KEY,
query_embeddings = generate_ollama_embeddings( app.state.OPENAI_API_BASE_URL,
GenerateEmbeddingsForm( )
**{
"model": app.state.RAG_EMBEDDING_MODEL,
"prompt": form_data.query,
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.OPENAI_API_KEY,
url=app.state.OPENAI_API_BASE_URL,
)
return query_embeddings_doc( return query_embeddings_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
embeddings_function=embeddings_function,
reranking_function=app.state.sentence_transformer_rf,
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
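
With the handler rewired, each query can carry its own `k` (top-k) and `r` (relevance threshold) instead of relying only on the server defaults. A request sketch under the same base URL and token assumptions as above (the collection name is a placeholder):

```python
import requests

BASE_URL = "http://localhost:8080/rag/api/v1"  # assumed mount point of this router
headers = {"Authorization": "Bearer <user token>"}

payload = {
    "collection_name": "docs-example",  # placeholder collection
    "query": "What does the reranking model change?",
    "k": 5,    # keep the 5 best chunks after fusion + reranking
    "r": 0.1,  # drop chunks whose reranking score is below 0.1
}
res = requests.post(f"{BASE_URL}/query/doc", headers=headers, json=payload)
print(res.json())  # {"distances": [...], "documents": [...], "metadatas": [...]}
```
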
@@ -326,6 +381,7 @@ class QueryCollectionsForm(BaseModel):
    collection_names: List[str]
    query: str
    k: Optional[int] = None
+    r: Optional[float] = None


@app.post("/query/collection")

@@ -334,33 +390,22 @@ def query_collection_handler(
    user=Depends(get_current_user),
):
    try:
-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            query_embeddings = app.state.sentence_transformer_ef.encode(
-                form_data.query
-            ).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            query_embeddings = generate_openai_embeddings(
-                model=app.state.RAG_EMBEDDING_MODEL,
-                text=form_data.query,
-                key=app.state.OPENAI_API_KEY,
-                url=app.state.OPENAI_API_BASE_URL,
-            )
+        embeddings_function = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )

        return query_embeddings_collection(
            collection_names=form_data.collection_names,
-            query_embeddings=query_embeddings,
+            query=form_data.query,
            k=form_data.k if form_data.k else app.state.TOP_K,
+            r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+            embeddings_function=embeddings_function,
+            reranking_function=app.state.sentence_transformer_rf,
        )
    except Exception as e:
        log.exception(e)
        raise HTTPException(

@@ -427,8 +472,6 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
    log.info(f"store_docs_in_vector_db {docs} {collection_name}")

    texts = [doc.page_content for doc in docs]
-    texts = list(map(lambda x: x.replace("\n", " "), texts))
    metadatas = [doc.metadata for doc in docs]

    try:

@@ -440,27 +483,16 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
        collection = CHROMA_CLIENT.create_collection(name=collection_name)

-        if app.state.RAG_EMBEDDING_ENGINE == "":
-            embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
-        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            embeddings = [
-                generate_ollama_embeddings(
-                    GenerateEmbeddingsForm(
-                        **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
-                    )
-                )
-                for text in texts
-            ]
-        elif app.state.RAG_EMBEDDING_ENGINE == "openai":
-            embeddings = [
-                generate_openai_embeddings(
-                    model=app.state.RAG_EMBEDDING_MODEL,
-                    text=text,
-                    key=app.state.OPENAI_API_KEY,
-                    url=app.state.OPENAI_API_BASE_URL,
-                )
-                for text in texts
-            ]
+        embedding_func = query_embeddings_function(
+            app.state.RAG_EMBEDDING_ENGINE,
+            app.state.RAG_EMBEDDING_MODEL,
+            app.state.sentence_transformer_ef,
+            app.state.OPENAI_API_KEY,
+            app.state.OPENAI_API_BASE_URL,
+        )
+
+        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
+        embeddings = embedding_func(embedding_texts)

        for batch in create_batches(
            api=CHROMA_CLIENT,


@@ -1,3 +1,4 @@
+import os
import logging
import requests

@@ -8,6 +9,15 @@ from apps.ollama.main import (
    GenerateEmbeddingsForm,
)

+from huggingface_hub import snapshot_download
+from langchain_core.documents import Document
+from langchain_community.retrievers import BM25Retriever
+from langchain.retrievers import (
+    ContextualCompressionRetriever,
+    EnsembleRetriever,
+)
+
from config import SRC_LOG_LEVELS, CHROMA_CLIENT

@@ -15,18 +25,53 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


-def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
+def query_embeddings_doc(
+    collection_name: str,
+    query: str,
+    k: int,
+    r: float,
+    embeddings_function,
+    reranking_function,
+):
    try:
        # if you use docker use the model from the environment variable
-        log.info(f"query_embeddings_doc {query_embeddings}")
        collection = CHROMA_CLIENT.get_collection(name=collection_name)
-        result = collection.query(
-            query_embeddings=[query_embeddings],
-            n_results=k,
-        )
-        log.info(f"query_embeddings_doc:result {result}")
+        documents = collection.get()  # get all documents
+
+        bm25_retriever = BM25Retriever.from_texts(
+            texts=documents.get("documents"),
+            metadatas=documents.get("metadatas"),
+        )
+        bm25_retriever.k = k
+
+        chroma_retriever = ChromaRetriever(
+            collection=collection,
+            embeddings_function=embeddings_function,
+            top_n=k,
+        )
+
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, chroma_retriever], weights=[0.5, 0.5]
+        )
+
+        compressor = RerankCompressor(
+            embeddings_function=embeddings_function,
+            reranking_function=reranking_function,
+            r_score=r,
+            top_n=k,
+        )
+
+        compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=ensemble_retriever
+        )
+
+        result = compression_retriever.invoke(query)
+        result = {
+            "distances": [[d.metadata.get("score") for d in result]],
+            "documents": [[d.page_content for d in result]],
+            "metadatas": [[d.metadata for d in result]],
+        }
+
        return result
    except Exception as e:
        raise e
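
The `EnsembleRetriever` above fuses the BM25 list and the Chroma list with equal weights; in LangChain that fusion is (roughly) weighted Reciprocal Rank Fusion, so a chunk that ranks well in either list rises to the top before the reranker sees it. A tiny sketch of that fusion step in isolation, with placeholder document ids:

```python
def weighted_rrf(result_lists, weights, c=60):
    """Fuse ranked id lists the way an ensemble retriever does (weighted RRF)."""
    scores = {}
    for ranked_ids, weight in zip(result_lists, weights):
        for rank, doc_id in enumerate(ranked_ids):
            scores[doc_id] = scores.get(doc_id, 0.0) + weight / (c + rank + 1)
    return sorted(scores, key=scores.get, reverse=True)


bm25_hits = ["doc-3", "doc-1", "doc-7"]    # lexical ranking (placeholders)
vector_hits = ["doc-1", "doc-5", "doc-3"]  # dense ranking (placeholders)

# doc-1 and doc-3 appear in both lists, so they are fused to the top.
print(weighted_rrf([bm25_hits, vector_hits], weights=[0.5, 0.5]))
```
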
@@ -34,63 +79,65 @@ def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k:
def merge_and_sort_query_results(query_results, k):
    # Initialize lists to store combined data
-    combined_ids = []
    combined_distances = []
-    combined_metadatas = []
    combined_documents = []
+    combined_metadatas = []

-    # Combine data from each dictionary
    for data in query_results:
-        combined_ids.extend(data["ids"][0])
        combined_distances.extend(data["distances"][0])
-        combined_metadatas.extend(data["metadatas"][0])
        combined_documents.extend(data["documents"][0])
+        combined_metadatas.extend(data["metadatas"][0])

-    # Create a list of tuples (distance, id, metadata, document)
-    combined = list(
-        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
-    )
+    # Create a list of tuples (distance, document, metadata)
+    combined = list(zip(combined_distances, combined_documents, combined_metadatas))

    # Sort the list based on distances
    combined.sort(key=lambda x: x[0])

-    # Unzip the sorted list
-    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
+    # We don't have anything :-(
+    if not combined:
+        sorted_distances = []
+        sorted_documents = []
+        sorted_metadatas = []
+    else:
+        # Unzip the sorted list
+        sorted_distances, sorted_documents, sorted_metadatas = zip(*combined)

    # Slicing the lists to include only k elements
    sorted_distances = list(sorted_distances)[:k]
-    sorted_ids = list(sorted_ids)[:k]
+    sorted_documents = list(sorted_documents)[:k]
    sorted_metadatas = list(sorted_metadatas)[:k]
-    sorted_documents = list(sorted_documents)[:k]

    # Create the output dictionary
-    merged_query_results = {
-        "ids": [sorted_ids],
+    result = {
        "distances": [sorted_distances],
-        "metadatas": [sorted_metadatas],
        "documents": [sorted_documents],
-        "embeddings": None,
-        "uris": None,
-        "data": None,
+        "metadatas": [sorted_metadatas],
    }

-    return merged_query_results
+    return result
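
The slimmed-down merge helper now tracks only distances, documents and metadatas, tolerates an empty result set, and still sorts ascending by the first field before trimming to `k`. A quick usage sketch with made-up per-collection results (assuming the function is imported from `apps.rag.utils`):

```python
from apps.rag.utils import merge_and_sort_query_results  # the helper defined above

collection_a = {
    "distances": [[0.21, 0.80]],
    "documents": [["alpha", "beta"]],
    "metadatas": [[{"source": "a.md"}, {"source": "a.md"}]],
}
collection_b = {
    "distances": [[0.05]],
    "documents": [["gamma"]],
    "metadatas": [[{"source": "b.md"}]],
}

merged = merge_and_sort_query_results([collection_a, collection_b], k=2)
print(merged["documents"])  # [['gamma', 'alpha']] -- smallest distance first
```
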
def query_embeddings_collection(
-    collection_names: List[str], query: str, query_embeddings, k: int
+    collection_names: List[str],
+    query: str,
+    k: int,
+    r: float,
+    embeddings_function,
+    reranking_function,
):
    results = []
-    log.info(f"query_embeddings_collection {query_embeddings}")

    for collection_name in collection_names:
        try:
            result = query_embeddings_doc(
                collection_name=collection_name,
                query=query,
-                query_embeddings=query_embeddings,
                k=k,
+                r=r,
+                embeddings_function=embeddings_function,
+                reranking_function=reranking_function,
            )
            results.append(result)
        except:

@@ -105,19 +152,57 @@ def rag_template(template: str, context: str, query: str):
    return template


-def rag_messages(
-    docs,
-    messages,
-    template,
-    k,
+def query_embeddings_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    openai_key,
    openai_url,
):
+    if embedding_engine == "":
+        return lambda query: embedding_function.encode(query).tolist()
+    elif embedding_engine in ["ollama", "openai"]:
+        if embedding_engine == "ollama":
+            func = lambda query: generate_ollama_embeddings(
+                GenerateEmbeddingsForm(
+                    **{
+                        "model": embedding_model,
+                        "prompt": query,
+                    }
+                )
+            )
+        elif embedding_engine == "openai":
+            func = lambda query: generate_openai_embeddings(
+                model=embedding_model,
+                text=query,
+                key=openai_key,
+                url=openai_url,
+            )
+
+        def generate_multiple(query, f):
+            if isinstance(query, list):
+                return [f(q) for q in query]
+            else:
+                return f(query)
+
+        return lambda query: generate_multiple(query, func)
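
`query_embeddings_function` hides the engine choice behind a single callable: the default sentence-transformers path relies on `encode()` accepting either one string or a list, while the Ollama/OpenAI paths are wrapped so lists are embedded one item at a time. A sketch of how the returned callable behaves for the default engine (assuming the function is imported from `apps.rag.utils`):

```python
from sentence_transformers import SentenceTransformer
from apps.rag.utils import query_embeddings_function  # defined above

# Stand-in for app.state.sentence_transformer_ef (embedding engine == "").
ef = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

embed = query_embeddings_function(
    embedding_engine="",
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_function=ef,
    openai_key="",
    openai_url="",
)

single = embed("hello world")      # one embedding: a flat list of floats
batch = embed(["hello", "world"])  # a list with one embedding per input
print(len(single), len(batch))     # e.g. 384 2 for all-MiniLM-L6-v2
```
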
+
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    r,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    reranking_function,
+    openai_key,
+    openai_url,
+):
    log.debug(
-        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
+        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {reranking_function} {openai_key} {openai_url}"
    )

    last_user_message_idx = None

@@ -145,62 +230,66 @@ def rag_messages(
    content_type = None
    query = ""

+    embeddings_function = query_embeddings_function(
+        embedding_engine,
+        embedding_model,
+        embedding_function,
+        openai_key,
+        openai_url,
+    )
+
+    extracted_collections = []
    relevant_contexts = []

    for doc in docs:
        context = None

-        try:
+        collection = doc.get("collection_name")
+        if collection:
+            collection = [collection]
+        else:
+            collection = doc.get("collection_names", [])
+
+        collection = set(collection).difference(extracted_collections)
+        if not collection:
+            log.debug(f"skipping {doc} as it has already been extracted")
+            continue
+
+        try:
            if doc["type"] == "text":
                context = doc["content"]
+            elif doc["type"] == "collection":
+                context = query_embeddings_collection(
+                    collection_names=doc["collection_names"],
+                    query=query,
+                    k=k,
+                    r=r,
+                    embeddings_function=embeddings_function,
+                    reranking_function=reranking_function,
+                )
            else:
-                if embedding_engine == "":
-                    query_embeddings = embedding_function.encode(query).tolist()
-                elif embedding_engine == "ollama":
-                    query_embeddings = generate_ollama_embeddings(
-                        GenerateEmbeddingsForm(
-                            **{
-                                "model": embedding_model,
-                                "prompt": query,
-                            }
-                        )
-                    )
-                elif embedding_engine == "openai":
-                    query_embeddings = generate_openai_embeddings(
-                        model=embedding_model,
-                        text=query,
-                        key=openai_key,
-                        url=openai_url,
-                    )
-
-                if doc["type"] == "collection":
-                    context = query_embeddings_collection(
-                        collection_names=doc["collection_names"],
-                        query=query,
-                        query_embeddings=query_embeddings,
-                        k=k,
-                    )
-                else:
-                    context = query_embeddings_doc(
-                        collection_name=doc["collection_name"],
-                        query=query,
-                        query_embeddings=query_embeddings,
-                        k=k,
-                    )
+                context = query_embeddings_doc(
+                    collection_name=doc["collection_name"],
+                    query=query,
+                    k=k,
+                    r=r,
+                    embeddings_function=embeddings_function,
+                    reranking_function=reranking_function,
+                )
        except Exception as e:
            log.exception(e)
            context = None

-        relevant_contexts.append(context)
+        if context:
+            relevant_contexts.append(context)

-    log.debug(f"relevant_contexts: {relevant_contexts}")
+        extracted_collections.extend(collection)

    context_string = ""
    for context in relevant_contexts:
-        if context:
-            context_string += " ".join(context["documents"][0]) + "\n"
+        items = context["documents"][0]
+        context_string += "\n\n".join(items)
+    context_string = context_string.strip()

    ra_content = rag_template(
        template=template,

@@ -208,6 +297,8 @@ def rag_messages(
        query=query,
    )

+    log.debug(f"ra_content: {ra_content}")
+
    if content_type == "list":
        new_content = []
        for content_item in user_message["content"]:

@@ -229,6 +320,44 @@ def rag_messages(
    return messages


+def get_model_path(model: str, update_model: bool = False):
+    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
+    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
+
+    local_files_only = not update_model
+
+    snapshot_kwargs = {
+        "cache_dir": cache_dir,
+        "local_files_only": local_files_only,
+    }
+
+    log.debug(f"model: {model}")
+    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
+
+    # Inspiration from upstream sentence_transformers
+    if (
+        os.path.exists(model)
+        or ("\\" in model or model.count("/") > 1)
+        and local_files_only
+    ):
+        # If fully qualified path exists, return input, else set repo_id
+        return model
+
+    elif "/" not in model:
+        # Set valid repo_id for model short-name
+        model = "sentence-transformers" + "/" + model
+
+    snapshot_kwargs["repo_id"] = model
+
+    # Attempt to query the huggingface_hub library to determine the local path and/or to update
+    try:
+        model_repo_path = snapshot_download(**snapshot_kwargs)
+        log.debug(f"model_repo_path: {model_repo_path}")
+        return model_repo_path
+    except Exception as e:
+        log.exception(f"Cannot determine model snapshot path: {e}")
+        return model
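
`get_model_path` mirrors how sentence-transformers itself resolves model names: filesystem paths pass through, bare names get the `sentence-transformers/` organisation prefix, and anything else is resolved (and optionally refreshed) through `huggingface_hub.snapshot_download` into `SENTENCE_TRANSFORMERS_HOME`. A small sketch of the behaviour, assuming the helper is imported from `apps.rag.utils`:

```python
import os
from apps.rag.utils import get_model_path  # helper defined above

os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/st-cache")

# Bare short name -> "sentence-transformers/all-MiniLM-L6-v2" -> local snapshot path
# (update_model=True allows a download, mirroring the *_AUTO_UPDATE settings).
print(get_model_path("all-MiniLM-L6-v2", update_model=True))

# A fully qualified repo id resolves to its cached snapshot; if nothing is cached
# and updates are disabled, the name is returned unchanged as a fallback.
print(get_model_path("cross-encoder/ms-marco-MiniLM-L-6-v2", update_model=False))
```
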
def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):

@@ -250,3 +379,97 @@ def generate_openai_embeddings(
    except Exception as e:
        print(e)
        return None
+
+
+from typing import Any
+
+from langchain_core.retrievers import BaseRetriever
+from langchain_core.callbacks import CallbackManagerForRetrieverRun
+
+
+class ChromaRetriever(BaseRetriever):
+    collection: Any
+    embeddings_function: Any
+    top_n: int
+
+    def _get_relevant_documents(
+        self,
+        query: str,
+        *,
+        run_manager: CallbackManagerForRetrieverRun,
+    ) -> List[Document]:
+        query_embeddings = self.embeddings_function(query)
+
+        results = self.collection.query(
+            query_embeddings=[query_embeddings],
+            n_results=self.top_n,
+        )
+
+        ids = results["ids"][0]
+        metadatas = results["metadatas"][0]
+        documents = results["documents"][0]
+
+        return [
+            Document(
+                metadata=metadatas[idx],
+                page_content=documents[idx],
+            )
+            for idx in range(len(ids))
+        ]
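
`ChromaRetriever` adapts a raw chromadb collection to LangChain's `BaseRetriever` interface so it can be dropped into the `EnsembleRetriever` above. A usage sketch against an in-memory Chroma collection, with a stand-in for the embedding callable produced by `query_embeddings_function`:

```python
import chromadb
from sentence_transformers import SentenceTransformer
from apps.rag.utils import ChromaRetriever  # defined above

st = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embed = lambda q: st.encode(q).tolist()  # stand-in embedding function

client = chromadb.Client()  # in-memory client, just for the sketch
collection = client.create_collection("demo")
texts = ["hybrid search fuses BM25 and vectors", "reranking uses a cross-encoder"]
collection.add(
    ids=["1", "2"],
    documents=texts,
    embeddings=[embed(t) for t in texts],
    metadatas=[{"source": "demo"}, {"source": "demo"}],
)

retriever = ChromaRetriever(collection=collection, embeddings_function=embed, top_n=1)
print(retriever.invoke("what does the reranker do?"))  # -> [Document(...)]
```
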
+
+import operator
+
+from typing import Optional, Sequence
+
+from langchain_core.documents import BaseDocumentCompressor, Document
+from langchain_core.callbacks import Callbacks
+from langchain_core.pydantic_v1 import Extra
+
+from sentence_transformers import util
+
+
+class RerankCompressor(BaseDocumentCompressor):
+    embeddings_function: Any
+    reranking_function: Any
+    r_score: float
+    top_n: int
+
+    class Config:
+        extra = Extra.forbid
+        arbitrary_types_allowed = True
+
+    def compress_documents(
+        self,
+        documents: Sequence[Document],
+        query: str,
+        callbacks: Optional[Callbacks] = None,
+    ) -> Sequence[Document]:
+        if self.reranking_function:
+            scores = self.reranking_function.predict(
+                [(query, doc.page_content) for doc in documents]
+            )
+        else:
+            query_embedding = self.embeddings_function(query)
+            document_embedding = self.embeddings_function(
+                [doc.page_content for doc in documents]
+            )
+            scores = util.cos_sim(query_embedding, document_embedding)[0]
+
+        docs_with_scores = list(zip(documents, scores.tolist()))
+        if self.r_score:
+            docs_with_scores = [
+                (d, s) for d, s in docs_with_scores if s >= self.r_score
+            ]
+
+        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)
+        final_results = []
+        for doc, doc_score in result[: self.top_n]:
+            metadata = doc.metadata
+            metadata["score"] = doc_score
+            doc = Document(
+                page_content=doc.page_content,
+                metadata=metadata,
+            )
+            final_results.append(doc)
+        return final_results
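
`RerankCompressor` is where the relevance threshold (`r_score`) actually bites: with a reranking model it rescores every candidate with the cross-encoder, otherwise it falls back to cosine similarity between query and document embeddings, and in both cases low-scoring documents are dropped and the score is stored in the document metadata. A direct usage sketch (the reranking model name is only an example):

```python
from langchain_core.documents import Document
from sentence_transformers import CrossEncoder
from apps.rag.utils import RerankCompressor  # defined above

docs = [
    Document(page_content="BM25 and vector hits are fused, then reranked."),
    Document(page_content="Dockerfiles describe container build steps."),
]

compressor = RerankCompressor(
    embeddings_function=None,  # unused when a reranking_function is supplied
    reranking_function=CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2"),
    r_score=0.0,  # 0.0 disables the threshold (falsy), so only top_n applies
    top_n=1,
)

kept = compressor.compress_documents(docs, query="how does reranking work?")
for doc in kept:
    print(round(doc.metadata["score"], 3), doc.page_content)
```
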


@@ -420,6 +420,9 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)

+RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
+RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
+
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")

RAG_EMBEDDING_MODEL = os.environ.get(

@@ -427,10 +430,26 @@ RAG_EMBEDDING_MODEL = os.environ.get(
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),

+RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

+RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
+if not RAG_RERANKING_MODEL == "":
+    log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
+
+RAG_RERANKING_MODEL_AUTO_UPDATE = (
+    os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
+)
+
+RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
+    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
+)
+
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")

@@ -439,16 +458,15 @@ if USE_CUDA.lower() == "true":
else:
    DEVICE_TYPE = "cpu"

CHROMA_CLIENT = chromadb.PersistentClient(
    path=CHROMA_DATA_PATH,
    settings=Settings(allow_reset=True, anonymized_telemetry=False),
)

-CHUNK_SIZE = 1500
-CHUNK_OVERLAP = 100
+CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
+CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))

-RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
+DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>

@@ -462,6 +480,8 @@ And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""

+RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
+
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
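
All of the new knobs are plain environment variables, so they can be set through the Dockerfile args above or directly in the environment before the backend imports `config`. A hedged example of setting them from Python (the reranking model name is only an example; the backend directory is assumed to be on `sys.path`):

```python
import os

# Hybrid search / reranking settings introduced by this commit.
os.environ["RAG_TOP_K"] = "5"
os.environ["RAG_RELEVANCE_THRESHOLD"] = "0.1"
os.environ["RAG_RERANKING_MODEL"] = "cross-encoder/ms-marco-MiniLM-L-6-v2"
os.environ["RAG_RERANKING_MODEL_AUTO_UPDATE"] = "true"  # download/refresh at startup
os.environ["CHUNK_SIZE"] = "1500"
os.environ["CHUNK_OVERLAP"] = "100"

import config  # values above are read once, at import time

print(config.RAG_RERANKING_MODEL, config.RAG_TOP_K, config.RAG_RELEVANCE_THRESHOLD)
```
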


@@ -120,9 +120,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                    data["messages"],
                    rag_app.state.RAG_TEMPLATE,
                    rag_app.state.TOP_K,
+                    rag_app.state.RELEVANCE_THRESHOLD,
                    rag_app.state.RAG_EMBEDDING_ENGINE,
                    rag_app.state.RAG_EMBEDDING_MODEL,
                    rag_app.state.sentence_transformer_ef,
+                    rag_app.state.sentence_transformer_rf,
                    rag_app.state.OPENAI_API_KEY,
                    rag_app.state.OPENAI_API_BASE_URL,
                )


@@ -123,6 +123,7 @@ export const getQuerySettings = async (token: string) => {
type QuerySettings = {
	k: number | null;
+	r: number | null;
	template: string | null;
};

@@ -413,3 +414,64 @@ export const updateEmbeddingConfig = async (token: string, payload: EmbeddingMod
	return res;
};
+
+export const getRerankingConfig = async (token: string) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking`, {
+		method: 'GET',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+type RerankingModelUpdateForm = {
+	reranking_model: string;
+};
+
+export const updateRerankingConfig = async (token: string, payload: RerankingModelUpdateForm) => {
+	let error = null;
+
+	const res = await fetch(`${RAG_API_BASE_URL}/reranking/update`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'application/json',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...payload
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.detail;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};


@@ -8,7 +8,9 @@
		updateQuerySettings,
		resetVectorDB,
		getEmbeddingConfig,
-		updateEmbeddingConfig
+		updateEmbeddingConfig,
+		getRerankingConfig,
+		updateRerankingConfig
	} from '$lib/apis/rag';

	import { documents, models } from '$lib/stores';

@@ -23,11 +25,13 @@
	let scanDirLoading = false;
	let updateEmbeddingModelLoading = false;
+	let updateRerankingModelLoading = false;

	let showResetConfirm = false;

	let embeddingEngine = '';
	let embeddingModel = '';
+	let rerankingModel = '';

	let OpenAIKey = '';
	let OpenAIUrl = '';

@@ -38,6 +42,7 @@
	let querySettings = {
		template: '',
+		r: 0.0,
		k: 4
	};

@@ -115,6 +120,29 @@
		}
	};

+	const rerankingModelUpdateHandler = async () => {
+		console.log('Update reranking model attempt:', rerankingModel);
+
+		updateRerankingModelLoading = true;
+		const res = await updateRerankingConfig(localStorage.token, {
+			reranking_model: rerankingModel
+		}).catch(async (error) => {
+			toast.error(error);
+			await setRerankingConfig();
+			return null;
+		});
+		updateRerankingModelLoading = false;
+
+		if (res) {
+			console.log('rerankingModelUpdateHandler:', res);
+			if (res.status === true) {
+				toast.success($i18n.t('Reranking model set to "{{reranking_model}}"', res), {
+					duration: 1000 * 10
+				});
+			}
+		}
+	};
+
	const submitHandler = async () => {
		const res = await updateRAGConfig(localStorage.token, {
			pdf_extract_images: pdfExtractImages,

@@ -138,6 +166,14 @@
		}
	};

+	const setRerankingConfig = async () => {
+		const rerankingConfig = await getRerankingConfig(localStorage.token);
+
+		if (rerankingConfig) {
+			rerankingModel = rerankingConfig.reranking_model;
+		}
+	};
+
	onMount(async () => {
		const res = await getRAGConfig(localStorage.token);

@@ -149,6 +185,7 @@
		}

		await setEmbeddingConfig();
+		await setRerankingConfig();

		querySettings = await getQuerySettings(localStorage.token);
	});
@@ -349,6 +386,79 @@
	<hr class=" dark:border-gray-700 my-3" />

+	<div class=" ">
+		<div class=" mb-2 text-sm font-medium">{$i18n.t('Update Reranking Model')}</div>
+
+		<div class="flex w-full">
+			<div class="flex-1 mr-2">
+				<input
+					class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+					placeholder={$i18n.t('Update reranking model (e.g. {{model}})', {
+						model: rerankingModel.slice(-40)
+					})}
+					bind:value={rerankingModel}
+				/>
+			</div>
+			<button
+				class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-850 dark:hover:bg-gray-800 dark:text-gray-100 rounded-lg transition"
+				on:click={() => {
+					rerankingModelUpdateHandler();
+				}}
+				disabled={updateRerankingModelLoading}
+			>
+				{#if updateRerankingModelLoading}
+					<div class="self-center">
+						<svg
+							class=" w-4 h-4"
+							viewBox="0 0 24 24"
+							fill="currentColor"
+							xmlns="http://www.w3.org/2000/svg"
+							><style>
+								.spinner_ajPY {
+									transform-origin: center;
+									animation: spinner_AtaB 0.75s infinite linear;
+								}
+								@keyframes spinner_AtaB {
+									100% {
+										transform: rotate(360deg);
+									}
+								}
+							</style><path
+								d="M12,1A11,11,0,1,0,23,12,11,11,0,0,0,12,1Zm0,19a8,8,0,1,1,8-8A8,8,0,0,1,12,20Z"
+								opacity=".25"
+							/><path
+								d="M10.14,1.16a11,11,0,0,0-9,8.92A1.59,1.59,0,0,0,2.46,12,1.52,1.52,0,0,0,4.11,10.7a8,8,0,0,1,6.66-6.61A1.42,1.42,0,0,0,12,2.69h0A1.57,1.57,0,0,0,10.14,1.16Z"
+								class="spinner_ajPY"
+							/></svg
+						>
+					</div>
+				{:else}
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						viewBox="0 0 16 16"
+						fill="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
+						/>
+						<path
+							d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
+						/>
+					</svg>
+				{/if}
+			</button>
+		</div>
+	</div>
+
+	<div class="mt-2 mb-1 text-xs text-gray-400 dark:text-gray-500">
+		{$i18n.t(
+			'Note: If you choose a reranking model, it will use that to score and rerank instead of the embedding model.'
+		)}
+	</div>
+
+	<hr class=" dark:border-gray-700 my-3" />
+
	<div class=" flex w-full justify-between">
		<div class=" self-center text-xs font-medium">
			{$i18n.t('Scan for documents from {{path}}', { path: '/data/docs' })}

@@ -473,6 +583,26 @@
		</div>
	</div>

+	<div class=" flex">
+		<div class=" flex w-full justify-between">
+			<div class="self-center text-xs font-medium flex-1">
+				{$i18n.t('Relevance Threshold')}
+			</div>
+			<div class="self-center p-3">
+				<input
+					class=" w-full rounded-lg py-1.5 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+					type="number"
+					step="0.01"
+					placeholder={$i18n.t('Enter Relevance Threshold')}
+					bind:value={querySettings.r}
+					autocomplete="off"
+					min="0.0"
+				/>
+			</div>
+		</div>
+	</div>
+
	<div>
		<div class=" mb-2.5 text-sm font-medium">{$i18n.t('RAG Template')}</div>
		<textarea