feat: move to native sentence_transformer

Steven Kreitzer 2024-04-22 13:27:43 -05:00
parent 22c50f62cb
commit f3e5700d49
7 changed files with 153 additions and 268 deletions
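The commit removes two layers of indirection: ChromaDB collections no longer carry an embedding_function, and the manual huggingface_hub snapshot resolution is gone. Instead, the app encodes each query once with a native sentence_transformers model and hands the resulting vector to every Chroma lookup. A minimal sketch of the new pattern (the model name is an assumed example, not taken from this commit):

# Sketch: native query embedding with sentence_transformers.
# "all-MiniLM-L6-v2" is an assumed example model, not part of this commit.
from sentence_transformers import SentenceTransformer

embedding_function = SentenceTransformer("all-MiniLM-L6-v2")
# encode() returns a numpy array; tolist() yields the plain list Chroma expects
query_embeddings = embedding_function.encode("what is open webui?").tolist()
print(len(query_embeddings))  # model dimensionality, e.g. 384 for MiniLM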


@@ -1,13 +1,12 @@
 import os
 import re
 import logging
-from typing import List
 import requests
+from typing import List
 
-from huggingface_hub import snapshot_download
-
-from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+from apps.ollama.main import (
+    generate_ollama_embeddings,
+    GenerateEmbeddingsForm,
+)
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
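Dropping snapshot_download is safe because SentenceTransformer resolves and caches Hub models on its own; the SENTENCE_TRANSFORMERS_HOME variable referenced in the helper removed at the end of this diff still controls that cache. A sketch, with an assumed cache path and model:

import os

# sentence_transformers downloads and caches the model itself; the cache
# directory below is an assumed example, not read from this diff.
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/data/cache/embedding/models"

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")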
@@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
 
-def query_doc(collection_name: str, query: str, k: int, embedding_function):
-    try:
-        # if you use docker use the model from the environment variable
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-            embedding_function=embedding_function,
-        )
-        result = collection.query(
-            query_texts=[query],
-            n_results=k,
-        )
-        return result
-    except Exception as e:
-        raise e
-
-
-def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
+def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
         log.info(f"query_embeddings_doc {query_embeddings}")
-        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
-        )
+        collection = CHROMA_CLIENT.get_collection(name=collection_name)
         result = collection.query(
             query_embeddings=[query_embeddings],
             n_results=k,
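query_embeddings_doc now receives a precomputed vector, so the collection is fetched by name alone and queried with query_embeddings rather than query_texts. A self-contained sketch of the same Chroma pattern (client, collection name, and documents are placeholders, not from this repo):

import chromadb
from sentence_transformers import SentenceTransformer

client = chromadb.Client()  # in-memory stand-in for CHROMA_CLIENT
model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed example model

collection = client.get_or_create_collection(name="demo-doc")
docs = ["Chroma can store documents next to their embeddings."]
collection.add(ids=["0"], documents=docs, embeddings=model.encode(docs).tolist())

# Query with a precomputed vector instead of query_texts:
result = collection.query(
    query_embeddings=[model.encode("where do embeddings live?").tolist()],
    n_results=1,
)
print(result["documents"])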
@@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
     return merged_query_results
 
 
-def query_collection(
-    collection_names: List[str], query: str, k: int, embedding_function
+def query_embeddings_collection(
+    collection_names: List[str], query: str, query_embeddings, k: int
 ):
     results = []
     for collection_name in collection_names:
         try:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=collection_name,
-                embedding_function=embedding_function,
-            )
-            result = collection.query(
-                query_texts=[query],
-                n_results=k,
-            )
-            results.append(result)
-        except:
-            pass
-    return merge_and_sort_query_results(results, k)
-
-
-def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
-    results = []
-    log.info(f"query_embeddings_collection {query_embeddings}")
-    for collection_name in collection_names:
-        try:
-            collection = CHROMA_CLIENT.get_collection(name=collection_name)
-            result = collection.query(
-                query_embeddings=[query_embeddings],
-                n_results=k,
+            result = query_embeddings_doc(
+                collection_name=collection_name,
+                query=query,
+                query_embeddings=query_embeddings,
+                k=k,
             )
             results.append(result)
         except:
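query_embeddings_collection is reduced to a loop that delegates each collection to query_embeddings_doc and merges the per-collection hits via merge_and_sort_query_results. That helper sits outside this hunk; a rough sketch of merging Chroma results by ascending distance (the field handling here is an assumption, not this repo's implementation):

def merge_and_sort_by_distance(query_results, k):
    # Chroma query() returns parallel lists nested one level per query;
    # flatten the first query's rows from every collection.
    pairs = []
    for r in query_results:
        pairs.extend(zip(r["distances"][0], r["documents"][0]))
    pairs.sort(key=lambda p: p[0])  # smaller distance = closer match
    top = pairs[:k]
    return {
        "distances": [[d for d, _ in top]],
        "documents": [[doc for _, doc in top]],
    }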
@@ -197,51 +156,38 @@ def rag_messages(
                 context = doc["content"]
             else:
                 if embedding_engine == "":
-                    if doc["type"] == "collection":
-                        context = query_collection(
-                            collection_names=doc["collection_names"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
-                        )
-                    else:
-                        context = query_doc(
-                            collection_name=doc["collection_name"],
-                            query=query,
-                            k=k,
-                            embedding_function=embedding_function,
-                        )
+                    query_embeddings = embedding_function.encode(query).tolist()
+                elif embedding_engine == "ollama":
+                    query_embeddings = generate_ollama_embeddings(
+                        GenerateEmbeddingsForm(
+                            **{
+                                "model": embedding_model,
+                                "prompt": query,
+                            }
+                        )
+                    )
+                elif embedding_engine == "openai":
+                    query_embeddings = generate_openai_embeddings(
+                        model=embedding_model,
+                        text=query,
+                        key=openai_key,
+                        url=openai_url,
+                    )
+
+                if doc["type"] == "collection":
+                    context = query_embeddings_collection(
+                        collection_names=doc["collection_names"],
+                        query=query,
+                        query_embeddings=query_embeddings,
+                        k=k,
+                    )
                 else:
-                    if embedding_engine == "ollama":
-                        query_embeddings = generate_ollama_embeddings(
-                            GenerateEmbeddingsForm(
-                                **{
-                                    "model": embedding_model,
-                                    "prompt": query,
-                                }
-                            )
-                        )
-                    elif embedding_engine == "openai":
-                        query_embeddings = generate_openai_embeddings(
-                            model=embedding_model,
-                            text=query,
-                            key=openai_key,
-                            url=openai_url,
-                        )
-                    if doc["type"] == "collection":
-                        context = query_embeddings_collection(
-                            collection_names=doc["collection_names"],
-                            query_embeddings=query_embeddings,
-                            k=k,
-                        )
-                    else:
-                        context = query_embeddings_doc(
-                            collection_name=doc["collection_name"],
-                            query_embeddings=query_embeddings,
-                            k=k,
-                        )
+                    context = query_embeddings_doc(
+                        collection_name=doc["collection_name"],
+                        query=query,
+                        query_embeddings=query_embeddings,
+                        k=k,
+                    )
         except Exception as e:
             log.exception(e)
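All three engines now converge on a single query_embeddings value before the Chroma lookup: the empty engine encodes locally, "ollama" goes through GenerateEmbeddingsForm(model, prompt), and "openai" calls generate_openai_embeddings. The Ollama form maps onto Ollama's embeddings endpoint; a direct requests sketch of that call (host and model are assumptions):

import requests

# Ollama embeddings API: POST /api/embeddings with model + prompt.
# Host and model name are assumed defaults, not read from this diff.
resp = requests.post(
    "http://localhost:11434/api/embeddings",
    json={"model": "nomic-embed-text", "prompt": "what is open webui?"},
)
resp.raise_for_status()
query_embeddings = resp.json()["embedding"]  # list[float]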
@@ -283,46 +229,6 @@ def rag_messages(
     return messages
 
 
-def get_embedding_model_path(
-    embedding_model: str, update_embedding_model: bool = False
-):
-    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
-    cache_dir = os.getenv("SENTENCE_TRANSFORMERS_HOME")
-    local_files_only = not update_embedding_model
-
-    snapshot_kwargs = {
-        "cache_dir": cache_dir,
-        "local_files_only": local_files_only,
-    }
-
-    log.debug(f"embedding_model: {embedding_model}")
-    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")
-
-    # Inspiration from upstream sentence_transformers
-    if (
-        os.path.exists(embedding_model)
-        or ("\\" in embedding_model or embedding_model.count("/") > 1)
-        and local_files_only
-    ):
-        # If fully qualified path exists, return input, else set repo_id
-        return embedding_model
-    elif "/" not in embedding_model:
-        # Set valid repo_id for model short-name
-        embedding_model = "sentence-transformers" + "/" + embedding_model
-
-    snapshot_kwargs["repo_id"] = embedding_model
-
-    # Attempt to query the huggingface_hub library to determine the local path and/or to update
-    try:
-        embedding_model_repo_path = snapshot_download(**snapshot_kwargs)
-        log.debug(f"embedding_model_repo_path: {embedding_model_repo_path}")
-        return embedding_model_repo_path
-    except Exception as e:
-        log.exception(f"Cannot determine embedding model snapshot path: {e}")
-        return embedding_model
-
-
 def generate_openai_embeddings(
     model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
 ):
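The body of generate_openai_embeddings is cut off in this view. A minimal implementation consistent with the signature and OpenAI's embeddings API might look like this (a sketch, not the committed code):

import requests

def generate_openai_embeddings_sketch(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
):
    # POST {url}/embeddings with a bearer token; the vector comes back
    # under data[0].embedding.
    r = requests.post(
        f"{url}/embeddings",
        headers={"Authorization": f"Bearer {key}"},
        json={"input": text, "model": model},
    )
    r.raise_for_status()
    return r.json()["data"][0]["embedding"]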