feat: move to native sentence_transformer

This commit is contained in:
Steven Kreitzer 2024-04-22 13:27:43 -05:00
parent 22c50f62cb
commit f3e5700d49
7 changed files with 153 additions and 268 deletions

View file

@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.121] - 2024-04-22
### Added
- **🛠️ Improved Embedding Model Support**: You can now use any embedding model `sentence_transformers` supports.
## [0.1.120] - 2024-04-20 ## [0.1.120] - 2024-04-20
### Added ### Added

View file

@ -8,8 +8,8 @@ ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) # for better performance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. # IMPORTANT: If you change the default model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=all-MiniLM-L6-v2 ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
######## WebUI frontend ######## ######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
@ -98,13 +98,13 @@ RUN pip3 install uv && \
# If you use CUDA the whisper and embedding model will be downloaded on first use # If you use CUDA the whisper and embedding model will be downloaded on first use
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
else \ else \
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
uv pip install --system -r requirements.txt --no-cache-dir && \ uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \ python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \ python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
fi fi

View file

@ -13,7 +13,6 @@ import os, shutil, logging, re
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from chromadb.utils import embedding_functions
from chromadb.utils.batch_utils import create_batches from chromadb.utils.batch_utils import create_batches
from langchain_community.document_loaders import ( from langchain_community.document_loaders import (
@ -38,6 +37,7 @@ import mimetypes
import uuid import uuid
import json import json
import sentence_transformers
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
@ -48,11 +48,8 @@ from apps.web.models.documents import (
) )
from apps.rag.utils import ( from apps.rag.utils import (
query_doc,
query_embeddings_doc, query_embeddings_doc,
query_collection,
query_embeddings_collection, query_embeddings_collection,
get_embedding_model_path,
generate_openai_embeddings, generate_openai_embeddings,
) )
@ -69,7 +66,7 @@ from config import (
DOCS_DIR, DOCS_DIR,
RAG_EMBEDDING_ENGINE, RAG_EMBEDDING_ENGINE,
RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
RAG_OPENAI_API_BASE_URL, RAG_OPENAI_API_BASE_URL,
RAG_OPENAI_API_KEY, RAG_OPENAI_API_KEY,
DEVICE_TYPE, DEVICE_TYPE,
@ -101,14 +98,11 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = False app.state.PDF_EXTRACT_IMAGES = False
if app.state.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = ( app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
embedding_functions.SentenceTransformerEmbeddingFunction( app.state.RAG_EMBEDDING_MODEL,
model_name=get_embedding_model_path(
app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
),
device=DEVICE_TYPE, device=DEVICE_TYPE,
) trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
) )
@ -185,13 +179,10 @@ async def update_embedding_config(
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key app.state.OPENAI_API_KEY = form_data.openai_config.key
else: else:
sentence_transformer_ef = ( sentence_transformer_ef = sentence_transformers.SentenceTransformer(
embedding_functions.SentenceTransformerEmbeddingFunction( app.state.RAG_EMBEDDING_MODEL,
model_name=get_embedding_model_path(
form_data.embedding_model, True
),
device=DEVICE_TYPE, device=DEVICE_TYPE,
) trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
) )
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = sentence_transformer_ef app.state.sentence_transformer_ef = sentence_transformer_ef
@ -294,17 +285,12 @@ def query_doc_handler(
form_data: QueryDocForm, form_data: QueryDocForm,
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.RAG_EMBEDDING_ENGINE == "": if app.state.RAG_EMBEDDING_ENGINE == "":
return query_doc( query_embeddings = app.state.sentence_transformer_ef.encode(
collection_name=form_data.collection_name, form_data.query
query=form_data.query, ).tolist()
k=form_data.k if form_data.k else app.state.TOP_K, elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings( query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(
**{ **{
@ -323,6 +309,7 @@ def query_doc_handler(
return query_embeddings_doc( return query_embeddings_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
) )
@ -348,15 +335,10 @@ def query_collection_handler(
): ):
try: try:
if app.state.RAG_EMBEDDING_ENGINE == "": if app.state.RAG_EMBEDDING_ENGINE == "":
return query_collection( query_embeddings = app.state.sentence_transformer_ef.encode(
collection_names=form_data.collection_names, form_data.query
query=form_data.query, ).tolist()
k=form_data.k if form_data.k else app.state.TOP_K, elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings( query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(
**{ **{
@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"store_docs_in_vector_db {docs} {collection_name}") log.info(f"store_docs_in_vector_db {docs} {collection_name}")
texts = [doc.page_content for doc in docs] texts = [doc.page_content for doc in docs]
texts = list(map(lambda x: x.replace("\n", " "), texts))
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
try: try:
@ -454,25 +438,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"deleting existing collection {collection_name}") log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name) CHROMA_CLIENT.delete_collection(name=collection_name)
if app.state.RAG_EMBEDDING_ENGINE == "":
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
documents=texts,
):
collection.add(*batch)
else:
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(name=collection_name)
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE == "":
embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
embeddings = [ embeddings = [
generate_ollama_embeddings( generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(

View file

@ -1,13 +1,12 @@
import os
import re
import logging import logging
from typing import List
import requests import requests
from typing import List
from huggingface_hub import snapshot_download from apps.ollama.main import (
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm generate_ollama_embeddings,
GenerateEmbeddingsForm,
)
from config import SRC_LOG_LEVELS, CHROMA_CLIENT from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"]) log.setLevel(SRC_LOG_LEVELS["RAG"])
def query_doc(collection_name: str, query: str, k: int, embedding_function): def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=embedding_function,
)
result = collection.query(
query_texts=[query],
n_results=k,
)
return result
except Exception as e:
raise e
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
try: try:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
log.info(f"query_embeddings_doc {query_embeddings}") log.info(f"query_embeddings_doc {query_embeddings}")
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(name=collection_name)
name=collection_name,
)
result = collection.query( result = collection.query(
query_embeddings=[query_embeddings], query_embeddings=[query_embeddings],
n_results=k, n_results=k,
@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
return merged_query_results return merged_query_results
def query_collection( def query_embeddings_collection(
collection_names: List[str], query: str, k: int, embedding_function collection_names: List[str], query: str, query_embeddings, k: int
): ):
results = []
for collection_name in collection_names:
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=embedding_function,
)
result = collection.query(
query_texts=[query],
n_results=k,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k)
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
results = [] results = []
log.info(f"query_embeddings_collection {query_embeddings}") log.info(f"query_embeddings_collection {query_embeddings}")
for collection_name in collection_names: for collection_name in collection_names:
try: try:
collection = CHROMA_CLIENT.get_collection(name=collection_name) result = query_embeddings_doc(
collection_name=collection_name,
result = collection.query( query=query,
query_embeddings=[query_embeddings], query_embeddings=query_embeddings,
n_results=k, k=k,
) )
results.append(result) results.append(result)
except: except:
@ -197,23 +156,8 @@ def rag_messages(
context = doc["content"] context = doc["content"]
else: else:
if embedding_engine == "": if embedding_engine == "":
if doc["type"] == "collection": query_embeddings = embedding_function.encode(query).tolist()
context = query_collection( elif embedding_engine == "ollama":
collection_names=doc["collection_names"],
query=query,
k=k,
embedding_function=embedding_function,
)
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=k,
embedding_function=embedding_function,
)
else:
if embedding_engine == "ollama":
query_embeddings = generate_ollama_embeddings( query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(
**{ **{
@ -233,12 +177,14 @@ def rag_messages(
if doc["type"] == "collection": if doc["type"] == "collection":
context = query_embeddings_collection( context = query_embeddings_collection(
collection_names=doc["collection_names"], collection_names=doc["collection_names"],
query=query,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
k=k, k=k,
) )
else: else:
context = query_embeddings_doc( context = query_embeddings_doc(
collection_name=doc["collection_name"], collection_name=doc["collection_name"],
query=query,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
k=k, k=k,
) )
@ -283,46 +229,6 @@ def rag_messages(
return messages return messages
def get_embedding_model_path(
embedding_model: str, update_embedding_model: bool = False
):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_embedding_model
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"embedding_model: {embedding_model}")
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
# Inspiration from upstream sentence_transformers
if (
os.path.exists(embedding_model)
or ("\\" in embedding_model or embedding_model.count("/") > 1)
and local_files_only
):
# If fully qualified path exists, return input, else set repo_id
return embedding_model
elif "/" not in embedding_model:
# Set valid repo_id for model short-name
embedding_model = "sentence-transformers" + "/" + embedding_model
snapshot_kwargs["repo_id"] = embedding_model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
return embedding_model_repo_path
except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model
def generate_openai_embeddings( def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com/v1" model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
): ):

View file

@ -411,18 +411,19 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
#################################### ####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true" os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
) )
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")

View file

@ -25,6 +25,7 @@ apscheduler
google-generativeai google-generativeai
langchain langchain
langchain-chroma
langchain-community langchain-community
fake_useragent fake_useragent
chromadb chromadb
@ -43,6 +44,7 @@ opencv-python-headless
rapidocr-onnxruntime rapidocr-onnxruntime
fpdf2 fpdf2
rank_bm25
faster-whisper faster-whisper

View file

@ -180,7 +180,7 @@
} }
}} }}
> >
<option value="">{$i18n.t('Default (SentenceTransformer)')}</option> <option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
<option value="ollama">{$i18n.t('Ollama')}</option> <option value="ollama">{$i18n.t('Ollama')}</option>
<option value="openai">{$i18n.t('OpenAI')}</option> <option value="openai">{$i18n.t('OpenAI')}</option>
</select> </select>