forked from open-webui/open-webui
feat: move to native sentence_transformer
This commit is contained in:
parent
22c50f62cb
commit
f3e5700d49
7 changed files with 153 additions and 268 deletions
|
@ -1,13 +1,12 @@
|
|||
import os
|
||||
import re
|
||||
import logging
|
||||
from typing import List
|
||||
import requests
|
||||
|
||||
from typing import List
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
||||
|
||||
from apps.ollama.main import (
|
||||
generate_ollama_embeddings,
|
||||
GenerateEmbeddingsForm,
|
||||
)
|
||||
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
||||
|
@ -16,29 +15,12 @@ log = logging.getLogger(__name__)
|
|||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def query_doc(collection_name: str, query: str, k: int, embedding_function):
    """Run a text query against a single Chroma collection.

    Args:
        collection_name: Name of the collection to fetch from ``CHROMA_CLIENT``.
        query: Query text, passed to the collection as a one-element
            ``query_texts`` list.
        k: Number of results requested (``n_results``).
        embedding_function: Embedding function handed to ``get_collection``.

    Returns:
        Whatever ``collection.query`` returns for the single query text.

    Raises:
        Exception: Any error from the lookup or the query is re-raised
            unchanged to the caller.
    """
    try:
        chroma_collection = CHROMA_CLIENT.get_collection(
            name=collection_name, embedding_function=embedding_function
        )
        return chroma_collection.query(query_texts=[query], n_results=k)
    except Exception as err:
        # Nothing is handled here; the error is handed straight back.
        raise err
|
||||
|
||||
|
||||
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
|
||||
def query_embeddings_doc(collection_name: str, query: str, query_embeddings, k: int):
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
log.info(f"query_embeddings_doc {query_embeddings}")
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
)
|
||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||
|
||||
result = collection.query(
|
||||
query_embeddings=[query_embeddings],
|
||||
n_results=k,
|
||||
|
@ -95,43 +77,20 @@ def merge_and_sort_query_results(query_results, k):
|
|||
return merged_query_results
|
||||
|
||||
|
||||
def query_collection(
|
||||
collection_names: List[str], query: str, k: int, embedding_function
|
||||
def query_embeddings_collection(
|
||||
collection_names: List[str], query: str, query_embeddings, k: int
|
||||
):
|
||||
|
||||
results = []
|
||||
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
|
||||
result = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=k,
|
||||
)
|
||||
results.append(result)
|
||||
except:
|
||||
pass
|
||||
|
||||
return merge_and_sort_query_results(results, k)
|
||||
|
||||
|
||||
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
|
||||
|
||||
results = []
|
||||
log.info(f"query_embeddings_collection {query_embeddings}")
|
||||
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
collection = CHROMA_CLIENT.get_collection(name=collection_name)
|
||||
|
||||
result = collection.query(
|
||||
query_embeddings=[query_embeddings],
|
||||
n_results=k,
|
||||
result = query_embeddings_doc(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
query_embeddings=query_embeddings,
|
||||
k=k,
|
||||
)
|
||||
results.append(result)
|
||||
except:
|
||||
|
@ -197,51 +156,38 @@ def rag_messages(
|
|||
context = doc["content"]
|
||||
else:
|
||||
if embedding_engine == "":
|
||||
if doc["type"] == "collection":
|
||||
context = query_collection(
|
||||
collection_names=doc["collection_names"],
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
else:
|
||||
context = query_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
query_embeddings = embedding_function.encode(query).tolist()
|
||||
elif embedding_engine == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": embedding_model,
|
||||
"prompt": query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif embedding_engine == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key,
|
||||
url=openai_url,
|
||||
)
|
||||
|
||||
if doc["type"] == "collection":
|
||||
context = query_embeddings_collection(
|
||||
collection_names=doc["collection_names"],
|
||||
query=query,
|
||||
query_embeddings=query_embeddings,
|
||||
k=k,
|
||||
)
|
||||
else:
|
||||
if embedding_engine == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": embedding_model,
|
||||
"prompt": query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif embedding_engine == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key,
|
||||
url=openai_url,
|
||||
)
|
||||
|
||||
if doc["type"] == "collection":
|
||||
context = query_embeddings_collection(
|
||||
collection_names=doc["collection_names"],
|
||||
query_embeddings=query_embeddings,
|
||||
k=k,
|
||||
)
|
||||
else:
|
||||
context = query_embeddings_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
query_embeddings=query_embeddings,
|
||||
k=k,
|
||||
)
|
||||
context = query_embeddings_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
query=query,
|
||||
query_embeddings=query_embeddings,
|
||||
k=k,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
@ -283,46 +229,6 @@ def rag_messages(
|
|||
return messages
|
||||
|
||||
|
||||
def get_embedding_model_path(
    embedding_model: str, update_embedding_model: bool = False
):
    """Resolve an embedding model identifier to a path usable for loading.

    The identifier is returned untouched when it exists on disk, or when it
    merely looks like a filesystem path (contains a backslash or more than
    one ``/``) and updating is disabled. Otherwise it is treated as a
    Hugging Face repo id (short names get the ``sentence-transformers/``
    prefix) and ``snapshot_download`` is asked for the snapshot directory.

    Args:
        embedding_model: Local path, repo id, or model short name.
        update_embedding_model: When false (the default),
            ``snapshot_download`` is called with ``local_files_only=True``.

    Returns:
        The input path, the snapshot path reported by ``snapshot_download``,
        or, if that call raises, the (possibly prefixed) repo id.
    """
    offline_only = not update_embedding_model

    # The cache directory comes from the SENTENCE_TRANSFORMERS_HOME env var.
    snapshot_kwargs = {
        "cache_dir": os.getenv("SENTENCE_TRANSFORMERS_HOME"),
        "local_files_only": offline_only,
    }

    log.debug(f"embedding_model: {embedding_model}")
    log.debug(f"snapshot_kwargs: {snapshot_kwargs}")

    # An existing filesystem entry always wins.
    if os.path.exists(embedding_model):
        return embedding_model

    # A path-shaped string that does not exist is returned as-is only when
    # updating is disabled; otherwise it falls through to the hub lookup.
    path_shaped = "\\" in embedding_model or embedding_model.count("/") > 1
    if path_shaped and offline_only:
        return embedding_model

    # A name without any "/" is a short name: qualify it with the default org.
    if "/" not in embedding_model:
        embedding_model = f"sentence-transformers/{embedding_model}"

    snapshot_kwargs["repo_id"] = embedding_model

    # Ask huggingface_hub for the snapshot location (and/or to update it).
    try:
        snapshot_path = snapshot_download(**snapshot_kwargs)
        log.debug(f"embedding_model_repo_path: {snapshot_path}")
        return snapshot_path
    except Exception as err:
        # Best effort: log the failure and fall back to the repo id.
        log.exception(f"Cannot determine embedding model snapshot path: {err}")
        return embedding_model
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
|
||||
):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue