forked from open-webui/open-webui
feat: move to native sentence_transformer
This commit is contained in:
parent 22c50f62cb
commit f3e5700d49
7 changed files with 153 additions and 268 deletions
CHANGELOG.md

@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.1.121] - 2024-04-22
+
+### Added
+
+- **🛠️ Improved Embedding Model Support**: You can now use any embedding model `sentence_transformers` supports.
+
 ## [0.1.120] - 2024-04-20

 ### Added
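For illustration, a minimal sketch of what this entry enables; the model id `intfloat/multilingual-e5-base` is only an example taken from the leaderboard referenced in the Dockerfile below, not a project default:

```python
# Hypothetical usage: point RAG_EMBEDDING_MODEL at any sentence-transformers
# model id and the backend loads it natively (no chromadb wrapper involved).
import os

os.environ["RAG_EMBEDDING_MODEL"] = "intfloat/multilingual-e5-base"

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(os.environ["RAG_EMBEDDING_MODEL"], device="cpu")
print(model.encode("test query").shape)  # dimension is model-specific, e.g. (768,)
```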
Dockerfile (12 changes)

@@ -8,8 +8,8 @@ ARG USE_CUDA_VER=cu121
 # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
 # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
 # for better performance and multilingual support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
-# IMPORTANT: if you change the default model (all-MiniLM-L6-v2), you can no longer use RAG Chat with documents already loaded in the WebUI; you need to re-embed them.
-ARG USE_EMBEDDING_MODEL=all-MiniLM-L6-v2
+# IMPORTANT: if you change the default model (sentence-transformers/all-MiniLM-L6-v2), you can no longer use RAG Chat with documents already loaded in the WebUI; you need to re-embed them.
+ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2

 ######## WebUI frontend ########
 FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build

@@ -98,13 +98,13 @@ RUN pip3 install uv && \
     # If you use CUDA the whisper and embedding model will be downloaded on first use
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
-    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
-    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+    python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
 else \
     pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
-    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
-    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
+    python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
+    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
 fi
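The `python -c` warm-up above exists so the image ships with the embedding model already cached. A rough equivalent written out as a script (a sketch, assuming the same env vars as the Dockerfile):

```python
# Pre-download the embedding model at build time so the container does not
# need network access on first request. Instantiating SentenceTransformer
# pulls the weights into the local cache (SENTENCE_TRANSFORMERS_HOME).
import os

from sentence_transformers import SentenceTransformer

model_id = os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
SentenceTransformer(model_id, device="cpu")
```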

backend/apps/rag/main.py

@@ -13,7 +13,6 @@ import os, shutil, logging, re
 from pathlib import Path
 from typing import List

-from chromadb.utils import embedding_functions
 from chromadb.utils.batch_utils import create_batches

 from langchain_community.document_loaders import (
@@ -38,6 +37,7 @@ import mimetypes
 import uuid
 import json

+import sentence_transformers

 from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
@@ -48,11 +48,8 @@ from apps.web.models.documents import (
 )

 from apps.rag.utils import (
-    query_doc,
     query_embeddings_doc,
-    query_collection,
     query_embeddings_collection,
-    get_embedding_model_path,
     generate_openai_embeddings,
 )
@@ -69,7 +66,7 @@ from config import (
     DOCS_DIR,
     RAG_EMBEDDING_ENGINE,
     RAG_EMBEDDING_MODEL,
-    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
+    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
     DEVICE_TYPE,
@@ -101,14 +98,11 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

 app.state.PDF_EXTRACT_IMAGES = False

-app.state.sentence_transformer_ef = (
-    embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=get_embedding_model_path(
-            app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
-        ),
+if app.state.RAG_EMBEDDING_ENGINE == "":
+    app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+        app.state.RAG_EMBEDDING_MODEL,
         device=DEVICE_TYPE,
+        trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
     )
-)
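The core of the change is visible here: `app.state` now holds a plain `SentenceTransformer` and the app calls `.encode()` itself, instead of handing Chroma an embedding function. A standalone sketch (model id and device are illustrative):

```python
import sentence_transformers

# Native model, as the app now constructs it when RAG_EMBEDDING_ENGINE == "".
model = sentence_transformers.SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device="cpu",
    trust_remote_code=False,  # mirrors RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE
)

# Embedding happens in application code from now on.
query_embeddings = model.encode("what is in this document?").tolist()
print(len(query_embeddings))  # 384 for all-MiniLM-L6-v2
```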
@@ -185,13 +179,10 @@ async def update_embedding_config(
         app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
         app.state.OPENAI_API_KEY = form_data.openai_config.key
     else:
-        sentence_transformer_ef = (
-            embedding_functions.SentenceTransformerEmbeddingFunction(
-                model_name=get_embedding_model_path(
-                    form_data.embedding_model, True
-                ),
+        sentence_transformer_ef = sentence_transformers.SentenceTransformer(
+            app.state.RAG_EMBEDDING_MODEL,
             device=DEVICE_TYPE,
+            trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
         )
-        )
         app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
         app.state.sentence_transformer_ef = sentence_transformer_ef
@@ -294,17 +285,12 @@ def query_doc_handler(
     form_data: QueryDocForm,
     user=Depends(get_current_user),
 ):
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "":
-            return query_doc(
-                collection_name=form_data.collection_name,
-                query=form_data.query,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-        else:
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            query_embeddings = app.state.sentence_transformer_ef.encode(
+                form_data.query
+            ).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
                 GenerateEmbeddingsForm(
                     **{
@@ -323,6 +309,7 @@ def query_doc_handler(

             return query_embeddings_doc(
                 collection_name=form_data.collection_name,
+                query=form_data.query,
                 query_embeddings=query_embeddings,
                 k=form_data.k if form_data.k else app.state.TOP_K,
             )
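Since every engine now produces a vector up front, the Chroma side reduces to a `query_embeddings` lookup. A self-contained sketch of that call (in-memory client and toy vectors; the app uses its configured `CHROMA_CLIENT`):

```python
import chromadb

client = chromadb.Client()  # in-memory, for illustration only
collection = client.get_or_create_collection(name="demo-doc")

# Store a document with a precomputed vector; no embedding_function attached.
collection.add(ids=["1"], documents=["hello world"], embeddings=[[0.1, 0.2, 0.3]])

# Query with a precomputed vector, exactly like query_embeddings_doc does.
result = collection.query(query_embeddings=[[0.1, 0.2, 0.3]], n_results=1)
print(result["documents"])  # [['hello world']]
```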
@@ -348,15 +335,10 @@ def query_collection_handler(
 ):
     try:
         if app.state.RAG_EMBEDDING_ENGINE == "":
-            return query_collection(
-                collection_names=form_data.collection_names,
-                query=form_data.query,
-                k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
-        else:
-
-            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+            query_embeddings = app.state.sentence_transformer_ef.encode(
+                form_data.query
+            ).tolist()
+        elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
             query_embeddings = generate_ollama_embeddings(
                 GenerateEmbeddingsForm(
                     **{
@@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
     log.info(f"store_docs_in_vector_db {docs} {collection_name}")

+    texts = [doc.page_content for doc in docs]
+    texts = list(map(lambda x: x.replace("\n", " "), texts))

     metadatas = [doc.metadata for doc in docs]

     try:
@@ -454,25 +438,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
         log.info(f"deleting existing collection {collection_name}")
         CHROMA_CLIENT.delete_collection(name=collection_name)

-    if app.state.RAG_EMBEDDING_ENGINE == "":
-
-        collection = CHROMA_CLIENT.create_collection(
-            name=collection_name,
-            embedding_function=app.state.sentence_transformer_ef,
-        )
-
-        for batch in create_batches(
-            api=CHROMA_CLIENT,
-            ids=[str(uuid.uuid1()) for _ in texts],
-            metadatas=metadatas,
-            documents=texts,
-        ):
-            collection.add(*batch)
-
-    else:
-        collection = CHROMA_CLIENT.create_collection(name=collection_name)
+    collection = CHROMA_CLIENT.create_collection(name=collection_name)

-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+    if app.state.RAG_EMBEDDING_ENGINE == "":
+        embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
+    elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
         embeddings = [
             generate_ollama_embeddings(
                 GenerateEmbeddingsForm(
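The storage path follows the same shape: embed all chunks first, then pass the vectors to Chroma directly. A hedged sketch (in-memory client and stand-in texts; the app batches through its own `CHROMA_CLIENT`):

```python
import uuid

import chromadb
from sentence_transformers import SentenceTransformer

client = chromadb.Client()  # stand-in for CHROMA_CLIENT
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")

texts = ["first chunk", "second chunk"]
embeddings = model.encode(texts).tolist()  # one vector per chunk

collection = client.create_collection(name="demo-collection")
collection.add(
    ids=[str(uuid.uuid1()) for _ in texts],
    documents=texts,
    embeddings=embeddings,
)
```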

backend/apps/rag/utils.py

@@ -1,13 +1,12 @@
 import os
 import re
 import logging
-from typing import List
 import requests

-from huggingface_hub import snapshot_download
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+from typing import List
+
+from apps.ollama.main import (
+    generate_ollama_embeddings,
+    GenerateEmbeddingsForm,
+)

 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-def query_doc(collection_name: str, query: str, k: int, embedding_function):
-    try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-            embedding_function=embedding_function,
-        )
-        result = collection.query(
-            query_texts=[query],
-            n_results=k,
-        )
-        return result
-    except Exception as e:
-        raise e
-
-
-def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
+def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
     try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-        )
+        log.info(f"query_embeddings_doc {query_embeddings}")
+        collection = CHROMA_CLIENT.get_collection(name=collection_name)

         result = collection.query(
             query_embeddings=[query_embeddings],
             n_results=k,
@@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
     return merged_query_results


-def query_collection(
-    collection_names: List[str], query: str, k: int, embedding_function
-):
-
-    results = []
-
-    for collection_name in collection_names:
-        try:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
-                embedding_function=embedding_function,
-            )
-
-            result = collection.query(
-                query_texts=[query],
-                n_results=k,
-            )
-            results.append(result)
-        except:
-            pass
-
-    return merge_and_sort_query_results(results, k)
-
-
-def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
+def query_embeddings_collection(
+    collection_names: List[str], query: str, query_embeddings, k: int
+):

     results = []
+    log.info(f"query_embeddings_collection {query_embeddings}")

     for collection_name in collection_names:
         try:
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
-
-            result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
+            result = query_embeddings_doc(
+                collection_name=collection_name,
+                query=query,
+                query_embeddings=query_embeddings,
+                k=k,
             )
             results.append(result)
         except:
@@ -197,23 +156,8 @@ def rag_messages(
                 context = doc["content"]
             else:
                 if embedding_engine == "":
-                    if doc["type"] == "collection":
-                        context = query_collection(
-                            collection_names=doc["collection_names"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
-                        )
-                    else:
-                        context = query_doc(
-                            collection_name=doc["collection_name"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
-                        )
-
-                else:
-                    if embedding_engine == "ollama":
+                    query_embeddings = embedding_function.encode(query).tolist()
+                elif embedding_engine == "ollama":
                     query_embeddings = generate_ollama_embeddings(
                         GenerateEmbeddingsForm(
                             **{
@@ -233,12 +177,14 @@ def rag_messages(
                 if doc["type"] == "collection":
                     context = query_embeddings_collection(
                         collection_names=doc["collection_names"],
+                        query=query,
                         query_embeddings=query_embeddings,
                         k=k,
                     )
                 else:
                     context = query_embeddings_doc(
                         collection_name=doc["collection_name"],
+                        query=query,
                         query_embeddings=query_embeddings,
                         k=k,
                     )
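Both hunks in `rag_messages` implement the same idea: encode the query once, then feed the identical vector to either the doc-level or the collection-level helper. A runnable sketch of the shared first step:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")

query = "what does the contract say about termination?"
query_embeddings = model.encode(query).tolist()  # single vector as a plain list

# The same vector then goes to whichever helper matches doc["type"], e.g.
# query_embeddings_doc(collection_name=..., query=query,
#                      query_embeddings=query_embeddings, k=4)
print(len(query_embeddings))  # 384
```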
@@ -283,46 +229,6 @@ def rag_messages(
     return messages


-def get_embedding_model_path(
-    embedding_model: str, update_embedding_model: bool = False
-):
-    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
-    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
-
-    local_files_only = not update_embedding_model
-
-    snapshot_kwargs = {
-        "cache_dir": cache_dir,
-        "local_files_only": local_files_only,
-    }
-
-    log.debug(f"embedding_model: {embedding_model}")
-    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
-
-    # Inspiration from upstream sentence_transformers
-    if (
-        os.path.exists(embedding_model)
-        or ("\\" in embedding_model or embedding_model.count("/") > 1)
-        and local_files_only
-    ):
-        # If fully qualified path exists, return input, else set repo_id
-        return embedding_model
-    elif "/" not in embedding_model:
-        # Set valid repo_id for model short-name
-        embedding_model = "sentence-transformers" + "/" + embedding_model
-
-    snapshot_kwargs["repo_id"] = embedding_model
-
-    # Attempt to query the huggingface_hub library to determine the local path and/or to update
-    try:
-        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
-        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
-        return embedding_model_repo_path
-    except Exception as e:
-        log.exception(f"Cannot determine embedding model snapshot path: {e}")
-        return embedding_model
-
-
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
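The deleted `get_embedding_model_path()` pre-resolved a local snapshot via `huggingface_hub` because chromadb's wrapper needed a concrete path. The native API makes that plumbing redundant: `SentenceTransformer` accepts a hub id or a local path and manages its own cache. A sketch (the cache directory is illustrative):

```python
import os

from sentence_transformers import SentenceTransformer

# sentence-transformers honors this env var for its model cache.
os.environ.setdefault("SENTENCE_TRANSFORMERS_HOME", "/tmp/st-cache")

# A hub id is downloaded and cached on first use, then reused afterwards;
# a local filesystem path would be loaded directly.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
```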

backend/config.py

@@ -411,18 +411,19 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
 ####################################

 CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-# this uses the model defined in the Dockerfile ENV variable. If you don't use Docker (or Docker-based deployments such as k8s), the default embedding model is used (all-MiniLM-L6-v2)
+# this uses the model defined in the Dockerfile ENV variable. If you don't use Docker (or Docker-based deployments such as k8s), the default embedding model is used (sentence-transformers/all-MiniLM-L6-v2)

 RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")

-RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
+RAG_EMBEDDING_MODEL = os.environ.get(
+    "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
+)
 log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),

-RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
-    os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
+RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
+    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )


 # device type for embedding models: "cpu" (default), "cuda" (NVIDIA GPU required) or "mps" (Apple Silicon); choosing the right one can improve performance
 USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
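One behavioral note on the new flag, shown as a sketch mirroring the lines above: only the exact string `"true"` (case-insensitive) enables it, so models that ship custom code must be opted in explicitly:

```python
import os

# Same parsing as backend/config.py: defaults to False unless the env var
# is literally "true" (any casing).
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
print(RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE)
```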

backend/requirements.txt

@@ -25,6 +25,7 @@ apscheduler
 google-generativeai

 langchain
+langchain-chroma
 langchain-community
 fake_useragent
 chromadb

@@ -43,6 +44,7 @@ opencv-python-headless
 rapidocr-onnxruntime

 fpdf2
+rank_bm25

 faster-whisper

(Svelte settings component)

@@ -180,7 +180,7 @@
     }
   }}
 >
-  <option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
+  <option value="">{$i18n.t('Default (SentenceTransformers)')}</option>
   <option value="ollama">{$i18n.t('Ollama')}</option>
   <option value="openai">{$i18n.t('OpenAI')}</option>
 </select>