forked from open-webui/open-webui
fix: address comment in pr #1687
This commit is contained in:
parent
d5f60b119c
commit
c9c9660459
4 changed files with 92 additions and 43 deletions
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import logging
|
||||
import requests
|
||||
|
||||
|
@ -8,6 +9,8 @@ from apps.ollama.main import (
|
|||
GenerateEmbeddingsForm,
|
||||
)
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_community.retrievers import BM25Retriever
|
||||
from langchain.retrievers import (
|
||||
|
@ -282,8 +285,6 @@ def rag_messages(
|
|||
|
||||
extracted_collections.extend(collection)
|
||||
|
||||
log.debug(f"relevant_contexts: {relevant_contexts}")
|
||||
|
||||
context_string = ""
|
||||
for context in relevant_contexts:
|
||||
items = context["documents"][0]
|
||||
|
@ -319,6 +320,44 @@ def rag_messages(
|
|||
return messages
|
||||
|
||||
|
||||
def get_model_path(model: str, update_model: bool = False):
    """Resolve an embedding model identifier to a path or repo snapshot.

    Args:
        model: A filesystem path, a ``org/name`` repo id, or a bare model
            short-name (no ``/``).
        update_model: When false (the default), ``snapshot_download`` is
            called with ``local_files_only=True``.

    Returns:
        ``model`` unchanged when it exists on disk, or when it looks like a
        path (contains a backslash or more than one ``/``) and
        ``update_model`` is false. Otherwise the value returned by
        ``snapshot_download``; if that call raises, the repo id that was
        attempted (a short-name is returned with the
        ``sentence-transformers/`` prefix applied).
    """
    local_files_only = not update_model

    # Keyword arguments forwarded to huggingface_hub.snapshot_download; the
    # cache directory comes from SENTENCE_TRANSFORMERS_HOME (None if unset).
    snapshot_kwargs = dict(
        cache_dir=os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        local_files_only=local_files_only,
    )

    log.debug(f"embedding_model: {model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # Path heuristic modelled on upstream sentence_transformers.
    # Grouping is explicit here: exists OR (path-like AND local-only).
    if os.path.exists(model) or (
        local_files_only and ("\\" in model or model.count("/") > 1)
    ):
        return model

    # A bare short-name has no organisation part; give it the default one.
    if "/" not in model:
        model = f"sentence-transformers/{model}"

    snapshot_kwargs["repo_id"] = model

    # Ask huggingface_hub for the snapshot location; any failure is logged
    # and the repo id is handed back so the caller can still try to load it.
    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"model_repo_path: {model_repo_path}")
        return model_repo_path
    except Exception as e:
        log.exception(f"Cannot determine model snapshot path: {e}")
        return model
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
|
||||
):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue