feat: external embeddings support

This commit is contained in:
Timothy J. Baek 2024-04-14 17:55:00 -04:00
parent 8b10b058e5
commit 2952e61167
6 changed files with 312 additions and 118 deletions

View file

@ -2,6 +2,9 @@ import os
import re
import logging
from typing import List
import requests
from huggingface_hub import snapshot_download
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@ -26,6 +29,21 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
raise e
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
return result
except Exception as e:
raise e
def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data
combined_ids = []
@ -96,6 +114,24 @@ def query_collection(
return merge_and_sort_query_results(results, k)
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
results = []
for collection_name in collection_names:
try:
collection = CHROMA_CLIENT.get_collection(name=collection_name)
result = collection.query(
query_embeddings=[query_embeddings],
n_results=k,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k)
def rag_template(template: str, context: str, query: str):
template = template.replace("[context]", context)
template = template.replace("[query]", query)