feat: openai embeddings integration

This commit is contained in:
Timothy J. Baek 2024-04-14 19:48:15 -04:00
parent b48e73fa43
commit b1b72441bb
6 changed files with 155 additions and 46 deletions

View file

@ -6,9 +6,12 @@ import requests
from huggingface_hub import snapshot_download
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
try:
# if you use docker use the model from the environment variable
log.info("query_embeddings_doc", query_embeddings)
log.info(f"query_embeddings_doc {query_embeddings}")
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
)
@ -118,7 +121,7 @@ def query_collection(
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
results = []
log.info("query_embeddings_collection", query_embeddings)
log.info(f"query_embeddings_collection {query_embeddings}")
for collection_name in collection_names:
try:
@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
return template
def rag_messages(docs, messages, template, k, embedding_function):
def rag_messages(
docs,
messages,
template,
k,
embedding_engine,
embedding_model,
embedding_function,
openai_key,
openai_url,
):
log.debug(f"docs: {docs}")
last_user_message_idx = None
@ -175,22 +188,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
context = None
try:
if doc["type"] == "collection":
context = query_collection(
collection_names=doc["collection_names"],
query=query,
k=k,
embedding_function=embedding_function,
)
elif doc["type"] == "text":
if doc["type"] == "text":
context = doc["content"]
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=k,
embedding_function=embedding_function,
)
if embedding_engine == "":
if doc["type"] == "collection":
context = query_collection(
collection_names=doc["collection_names"],
query=query,
k=k,
embedding_function=embedding_function,
)
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=k,
embedding_function=embedding_function,
)
else:
if embedding_engine == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{
"model": embedding_model,
"prompt": query,
}
)
)
elif embedding_engine == "openai":
query_embeddings = generate_openai_embeddings(
model=embedding_model,
text=query,
key=openai_key,
url=openai_url,
)
if doc["type"] == "collection":
context = query_embeddings_collection(
collection_names=doc["collection_names"],
query_embeddings=query_embeddings,
k=k,
)
else:
context = query_embeddings_doc(
collection_name=doc["collection_name"],
query_embeddings=query_embeddings,
k=k,
)
except Exception as e:
log.exception(e)
context = None