fix: address comment in pr #1687

This commit is contained in:
Steven Kreitzer 2024-04-25 07:49:59 -05:00
parent d5f60b119c
commit c9c9660459
4 changed files with 92 additions and 43 deletions

View file

@ -1,3 +1,4 @@
import os
import logging
import requests
@ -8,6 +9,8 @@ from apps.ollama.main import (
GenerateEmbeddingsForm,
)
from huggingface_hub import snapshot_download
from langchain_core.documents import Document
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import (
@ -282,8 +285,6 @@ def rag_messages(
extracted_collections.extend(collection)
log.debug(f"relevant_contexts: {relevant_contexts}")
context_string = ""
for context in relevant_contexts:
items = context["documents"][0]
@ -319,6 +320,44 @@ def rag_messages(
return messages
def get_model_path(model: str, update_model: bool = False):
# Construct huggingface_hub kwargs with local_files_only to return the snapshot path
cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
local_files_only = not update_model
snapshot_kwargs = {
"cache_dir": cache_dir,
"local_files_only": local_files_only,
}
log.debug(f"embedding_model: {model}")
log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
# Inspiration from upstream sentence_transformers
if (
os.path.exists(model)
or ("\\" in model or model.count("/") > 1)
and local_files_only
):
# If fully qualified path exists, return input, else set repo_id
return model
elif "/" not in model:
# Set valid repo_id for model short-name
model = "sentence-transformers" + "/" + model
snapshot_kwargs["repo_id"] = model
# Attempt to query the huggingface_hub library to determine the local path and/or to update
try:
model_repo_path = snapshot_download(**snapshot_kwargs)
log.debug(f"model_repo_path: {model_repo_path}")
return model_repo_path
except Exception as e:
log.exception(f"Cannot determine model snapshot path: {e}")
return model
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):