forked from open-webui/open-webui
feat: openai embeddings integration
This commit is contained in:
parent
b48e73fa43
commit
b1b72441bb
6 changed files with 155 additions and 46 deletions
|
@ -659,7 +659,7 @@ def generate_ollama_embeddings(
|
|||
url_idx: Optional[int] = None,
|
||||
):
|
||||
|
||||
log.info("generate_ollama_embeddings", form_data)
|
||||
log.info(f"generate_ollama_embeddings {form_data}")
|
||||
|
||||
if url_idx == None:
|
||||
model = form_data.model
|
||||
|
@ -688,7 +688,7 @@ def generate_ollama_embeddings(
|
|||
|
||||
data = r.json()
|
||||
|
||||
log.info("generate_ollama_embeddings", data)
|
||||
log.info(f"generate_ollama_embeddings {data}")
|
||||
|
||||
if "embedding" in data:
|
||||
return data["embedding"]
|
||||
|
|
|
@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
|
|||
docs = text_splitter.split_documents(data)
|
||||
|
||||
if len(docs) > 0:
|
||||
log.info("store_data_in_vector_db", "store_docs_in_vector_db")
|
||||
log.info(f"store_data_in_vector_db {docs}")
|
||||
return store_docs_in_vector_db(docs, collection_name, overwrite), None
|
||||
else:
|
||||
raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
|
||||
|
@ -440,7 +440,7 @@ def store_text_in_vector_db(
|
|||
|
||||
|
||||
def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
|
||||
log.info("store_docs_in_vector_db", docs, collection_name)
|
||||
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
||||
|
||||
texts = [doc.page_content for doc in docs]
|
||||
metadatas = [doc.metadata for doc in docs]
|
||||
|
@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
collection.add(*batch)
|
||||
|
||||
else:
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
embeddings = [
|
||||
generate_ollama_embeddings(
|
||||
|
|
|
@ -6,9 +6,12 @@ import requests
|
|||
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
||||
|
||||
|
||||
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
|
|||
def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
|
||||
try:
|
||||
# if you use docker use the model from the environment variable
|
||||
log.info("query_embeddings_doc", query_embeddings)
|
||||
log.info(f"query_embeddings_doc {query_embeddings}")
|
||||
collection = CHROMA_CLIENT.get_collection(
|
||||
name=collection_name,
|
||||
)
|
||||
|
@ -118,7 +121,7 @@ def query_collection(
|
|||
def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
|
||||
|
||||
results = []
|
||||
log.info("query_embeddings_collection", query_embeddings)
|
||||
log.info(f"query_embeddings_collection {query_embeddings}")
|
||||
|
||||
for collection_name in collection_names:
|
||||
try:
|
||||
|
@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
|
|||
return template
|
||||
|
||||
|
||||
def rag_messages(docs, messages, template, k, embedding_function):
|
||||
def rag_messages(
|
||||
docs,
|
||||
messages,
|
||||
template,
|
||||
k,
|
||||
embedding_engine,
|
||||
embedding_model,
|
||||
embedding_function,
|
||||
openai_key,
|
||||
openai_url,
|
||||
):
|
||||
log.debug(f"docs: {docs}")
|
||||
|
||||
last_user_message_idx = None
|
||||
|
@ -175,22 +188,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
|
|||
context = None
|
||||
|
||||
try:
|
||||
if doc["type"] == "collection":
|
||||
context = query_collection(
|
||||
collection_names=doc["collection_names"],
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
elif doc["type"] == "text":
|
||||
|
||||
if doc["type"] == "text":
|
||||
context = doc["content"]
|
||||
else:
|
||||
context = query_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
if embedding_engine == "":
|
||||
if doc["type"] == "collection":
|
||||
context = query_collection(
|
||||
collection_names=doc["collection_names"],
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
else:
|
||||
context = query_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
query=query,
|
||||
k=k,
|
||||
embedding_function=embedding_function,
|
||||
)
|
||||
|
||||
else:
|
||||
if embedding_engine == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": embedding_model,
|
||||
"prompt": query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif embedding_engine == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
key=openai_key,
|
||||
url=openai_url,
|
||||
)
|
||||
|
||||
if doc["type"] == "collection":
|
||||
context = query_embeddings_collection(
|
||||
collection_names=doc["collection_names"],
|
||||
query_embeddings=query_embeddings,
|
||||
k=k,
|
||||
)
|
||||
else:
|
||||
context = query_embeddings_doc(
|
||||
collection_name=doc["collection_name"],
|
||||
query_embeddings=query_embeddings,
|
||||
k=k,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
context = None
|
||||
|
|
|
@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|||
data["messages"],
|
||||
rag_app.state.RAG_TEMPLATE,
|
||||
rag_app.state.TOP_K,
|
||||
rag_app.state.RAG_EMBEDDING_ENGINE,
|
||||
rag_app.state.RAG_EMBEDDING_MODEL,
|
||||
rag_app.state.sentence_transformer_ef,
|
||||
rag_app.state.RAG_OPENAI_API_KEY,
|
||||
rag_app.state.RAG_OPENAI_API_BASE_URL,
|
||||
)
|
||||
del data["docs"]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue