forked from open-webui/open-webui
fix: address comment in pr #1687
This commit is contained in:
parent
d5f60b119c
commit
c9c9660459
4 changed files with 92 additions and 43 deletions
|
@ -92,10 +92,6 @@ async def get_ollama_api_urls(user=Depends(get_admin_user)):
|
||||||
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
|
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
|
||||||
|
|
||||||
|
|
||||||
def get_ollama_endpoint(url_idx: int = 0):
|
|
||||||
return app.state.OLLAMA_BASE_URLS[url_idx]
|
|
||||||
|
|
||||||
|
|
||||||
class UrlUpdateForm(BaseModel):
|
class UrlUpdateForm(BaseModel):
|
||||||
urls: List[str]
|
urls: List[str]
|
||||||
|
|
||||||
|
|
|
@ -39,8 +39,6 @@ import json
|
||||||
|
|
||||||
import sentence_transformers
|
import sentence_transformers
|
||||||
|
|
||||||
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
|
||||||
|
|
||||||
from apps.web.models.documents import (
|
from apps.web.models.documents import (
|
||||||
Documents,
|
Documents,
|
||||||
DocumentForm,
|
DocumentForm,
|
||||||
|
@ -48,6 +46,7 @@ from apps.web.models.documents import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from apps.rag.utils import (
|
from apps.rag.utils import (
|
||||||
|
get_model_path,
|
||||||
query_embeddings_doc,
|
query_embeddings_doc,
|
||||||
query_embeddings_function,
|
query_embeddings_function,
|
||||||
query_embeddings_collection,
|
query_embeddings_collection,
|
||||||
|
@ -60,6 +59,7 @@ from utils.misc import (
|
||||||
extract_folders_after_data_docs,
|
extract_folders_after_data_docs,
|
||||||
)
|
)
|
||||||
from utils.utils import get_current_user, get_admin_user
|
from utils.utils import get_current_user, get_admin_user
|
||||||
|
|
||||||
from config import (
|
from config import (
|
||||||
SRC_LOG_LEVELS,
|
SRC_LOG_LEVELS,
|
||||||
UPLOAD_DIR,
|
UPLOAD_DIR,
|
||||||
|
@ -68,8 +68,10 @@ from config import (
|
||||||
RAG_RELEVANCE_THRESHOLD,
|
RAG_RELEVANCE_THRESHOLD,
|
||||||
RAG_EMBEDDING_ENGINE,
|
RAG_EMBEDDING_ENGINE,
|
||||||
RAG_EMBEDDING_MODEL,
|
RAG_EMBEDDING_MODEL,
|
||||||
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||||
RAG_RERANKING_MODEL,
|
RAG_RERANKING_MODEL,
|
||||||
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
RAG_OPENAI_API_BASE_URL,
|
RAG_OPENAI_API_BASE_URL,
|
||||||
RAG_OPENAI_API_KEY,
|
RAG_OPENAI_API_KEY,
|
||||||
|
@ -87,13 +89,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
|
||||||
app.state.TOP_K = RAG_TOP_K
|
app.state.TOP_K = RAG_TOP_K
|
||||||
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
|
||||||
app.state.CHUNK_SIZE = CHUNK_SIZE
|
app.state.CHUNK_SIZE = CHUNK_SIZE
|
||||||
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
|
||||||
|
|
||||||
|
|
||||||
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||||
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
|
||||||
|
@ -104,18 +104,28 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
||||||
|
|
||||||
app.state.PDF_EXTRACT_IMAGES = False
|
app.state.PDF_EXTRACT_IMAGES = False
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
|
||||||
|
def update_embedding_model(
|
||||||
|
embedding_model: str,
|
||||||
|
update_model: bool = False,
|
||||||
|
):
|
||||||
|
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
|
||||||
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
get_model_path(embedding_model, update_model),
|
||||||
device=DEVICE_TYPE,
|
device=DEVICE_TYPE,
|
||||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
app.state.sentence_transformer_ef = None
|
app.state.sentence_transformer_ef = None
|
||||||
|
|
||||||
if not app.state.RAG_RERANKING_MODEL == "":
|
|
||||||
|
def update_reranking_model(
|
||||||
|
reranking_model: str,
|
||||||
|
update_model: bool = False,
|
||||||
|
):
|
||||||
|
if reranking_model:
|
||||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
||||||
app.state.RAG_RERANKING_MODEL,
|
get_model_path(reranking_model, update_model),
|
||||||
device=DEVICE_TYPE,
|
device=DEVICE_TYPE,
|
||||||
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
trust_remote_code=RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
|
||||||
)
|
)
|
||||||
|
@ -123,8 +133,19 @@ else:
|
||||||
app.state.sentence_transformer_rf = None
|
app.state.sentence_transformer_rf = None
|
||||||
|
|
||||||
|
|
||||||
|
update_embedding_model(
|
||||||
|
app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||||
|
)
|
||||||
|
|
||||||
|
update_reranking_model(
|
||||||
|
app.state.RAG_RERANKING_MODEL,
|
||||||
|
RAG_RERANKING_MODEL_AUTO_UPDATE,
|
||||||
|
)
|
||||||
|
|
||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=origins,
|
allow_origins=origins,
|
||||||
|
@ -200,15 +221,7 @@ async def update_embedding_config(
|
||||||
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||||
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
||||||
|
|
||||||
app.state.sentence_transformer_ef = None
|
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
|
||||||
else:
|
|
||||||
app.state.sentence_transformer_ef = (
|
|
||||||
sentence_transformers.SentenceTransformer(
|
|
||||||
app.state.RAG_EMBEDDING_MODEL,
|
|
||||||
device=DEVICE_TYPE,
|
|
||||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
|
@ -219,7 +232,6 @@ async def update_embedding_config(
|
||||||
"key": app.state.OPENAI_API_KEY,
|
"key": app.state.OPENAI_API_KEY,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Problem updating embedding model: {e}")
|
log.exception(f"Problem updating embedding model: {e}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -242,13 +254,7 @@ async def update_reranking_config(
|
||||||
try:
|
try:
|
||||||
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
|
||||||
|
|
||||||
if app.state.RAG_RERANKING_MODEL == "":
|
update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
|
||||||
app.state.sentence_transformer_rf = None
|
|
||||||
else:
|
|
||||||
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
|
|
||||||
app.state.RAG_RERANKING_MODEL,
|
|
||||||
device=DEVICE_TYPE,
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": True,
|
"status": True,
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import logging
|
import logging
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
@ -8,6 +9,8 @@ from apps.ollama.main import (
|
||||||
GenerateEmbeddingsForm,
|
GenerateEmbeddingsForm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_community.retrievers import BM25Retriever
|
from langchain_community.retrievers import BM25Retriever
|
||||||
from langchain.retrievers import (
|
from langchain.retrievers import (
|
||||||
|
@ -282,8 +285,6 @@ def rag_messages(
|
||||||
|
|
||||||
extracted_collections.extend(collection)
|
extracted_collections.extend(collection)
|
||||||
|
|
||||||
log.debug(f"relevant_contexts: {relevant_contexts}")
|
|
||||||
|
|
||||||
context_string = ""
|
context_string = ""
|
||||||
for context in relevant_contexts:
|
for context in relevant_contexts:
|
||||||
items = context["documents"][0]
|
items = context["documents"][0]
|
||||||
|
@ -319,6 +320,44 @@ def rag_messages(
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_path(model: str, update_model: bool = False):
    """Resolve a model name or path to something loadable from disk.

    Args:
        model: A filesystem path, a Hugging Face repo id ("org/name"), or a
            bare short name (which is looked up under the
            "sentence-transformers" organisation).
        update_model: When False (the default) the hub lookup is restricted
            to files that are already local (``local_files_only=True``);
            when True the hub call is allowed to fetch.

    Returns:
        The input unchanged when it is an existing path, or when it looks
        like a path and only local files may be used. Otherwise the value
        returned by ``snapshot_download`` for the repo, or the (possibly
        prefixed) repo id if that call raises.
    """
    offline_only = not update_model

    # kwargs forwarded to huggingface_hub.snapshot_download; the cache
    # directory comes from SENTENCE_TRANSFORMERS_HOME (None when unset).
    hub_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": offline_only,
    }

    log.debug(f"embedding_model: {model}")
    log.debug(f"snapshot_kwargs: {hub_kwargs}")

    # Inspiration from upstream sentence_transformers.
    # An existing path is always returned as-is.
    if os.path.exists(model):
        return model

    # Something that looks like a path (backslash, or more than one "/")
    # cannot be a repo id; in offline mode hand it back untouched.
    looks_like_path = ("\\" in model) or (model.count("/") > 1)
    if looks_like_path and offline_only:
        return model

    if "/" not in model:
        # Short model name: qualify it into a valid repo id.
        model = "/".join(("sentence-transformers", model))

    hub_kwargs["repo_id"] = model

    # Ask huggingface_hub for the snapshot location (and, when allowed, to
    # update it). Any failure falls back to returning the repo id itself.
    try:
        snapshot_path = snapshot_download(**hub_kwargs)
        log.debug(f"model_repo_path: {snapshot_path}")
        return snapshot_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
|
||||||
|
|
||||||
|
|
||||||
def generate_openai_embeddings(
|
def generate_openai_embeddings(
|
||||||
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
|
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
|
||||||
):
|
):
|
||||||
|
|
|
@ -430,6 +430,10 @@ RAG_EMBEDDING_MODEL = os.environ.get(
|
||||||
)
|
)
|
||||||
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
|
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
|
||||||
|
|
||||||
|
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
||||||
|
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
||||||
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
@ -438,6 +442,10 @@ RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
|
||||||
if not RAG_RERANKING_MODEL == "":
|
if not RAG_RERANKING_MODEL == "":
|
||||||
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
|
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
|
||||||
|
|
||||||
|
RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
||||||
|
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
||||||
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue