feat: external embeddings support

Timothy J. Baek 2024-04-14 17:55:00 -04:00
parent 8b10b058e5
commit 2952e61167
6 changed files with 312 additions and 118 deletions

apps/ollama/main.py

@@ -654,6 +654,55 @@ async def generate_embeddings(
        )


def generate_ollama_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
):
    if url_idx == None:
        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"
        if model in app.state.MODELS:
            url_idx = random.choice(app.state.MODELS[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        data = r.json()
        if "embedding" in data:
            return data["embedding"]
        else:
            raise "Something went wrong :/"
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except:
                error_detail = f"Ollama: {e}"

        raise error_detail


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
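
Editor's note: the new helper is a thin synchronous wrapper around Ollama's /api/embeddings endpoint. A minimal standalone sketch of the same call; the base URL and model name below are placeholders, not values from this commit:

import requests

OLLAMA_BASE_URL = "http://localhost:11434"  # assumed local Ollama instance


def embed(prompt: str, model: str = "nomic-embed-text") -> list[float]:
    # POST /api/embeddings returns {"embedding": [...]} for a single prompt
    r = requests.post(
        f"{OLLAMA_BASE_URL}/api/embeddings",
        json={"model": model, "prompt": prompt},
    )
    r.raise_for_status()
    return r.json()["embedding"]


if __name__ == "__main__":
    print(len(embed("hello world")))  # embedding dimensionality, e.g. 768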

apps/rag/main.py

@@ -39,13 +39,21 @@ import uuid
import json

from apps.ollama.main import generate_ollama_embeddings

from apps.web.models.documents import (
    Documents,
    DocumentForm,
    DocumentResponse,
)

from apps.rag.utils import query_doc, query_collection, get_embedding_model_path
from apps.rag.utils import (
    query_doc,
    query_embeddings_doc,
    query_collection,
    query_embeddings_collection,
    get_embedding_model_path,
)

from utils.misc import (
    calculate_sha256,
@@ -58,6 +66,7 @@ from config import (
    SRC_LOG_LEVELS,
    UPLOAD_DIR,
    DOCS_DIR,
    RAG_EMBEDDING_ENGINE,
    RAG_EMBEDDING_MODEL,
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    DEVICE_TYPE,
@@ -74,17 +83,20 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()
app.state.PDF_EXTRACT_IMAGES = False
app.state.TOP_K = 4
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.PDF_EXTRACT_IMAGES = False
app.state.TOP_K = 4
app.state.sentence_transformer_ef = (
    embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=get_embedding_model_path(
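
RAG_EMBEDDING_ENGINE itself is defined in config.py, which is not part of this diff; presumably it is read from an environment variable, roughly along these lines (a guess, not code from the commit):

import os

# hypothetical config.py side of this change: an empty string keeps the default
# sentence-transformers path, "ollama" switches embedding over to the Ollama API
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")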
@@ -121,6 +133,7 @@ async def get_status():
        "chunk_size": app.state.CHUNK_SIZE,
        "chunk_overlap": app.state.CHUNK_OVERLAP,
        "template": app.state.RAG_TEMPLATE,
        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
    }
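
The new field is surfaced on the RAG status route, so a client can see which engine is active. A quick check, assuming a locally running instance and the usual /rag/api/v1 mount point (both are assumptions, adjust for your deployment):

import requests

status = requests.get("http://localhost:8080/rag/api/v1/").json()
print(status.get("embedding_engine"), status.get("embedding_model"))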
@@ -252,12 +265,23 @@ def query_doc_handler(
):
    try:
        return query_doc(
            collection_name=form_data.collection_name,
            query=form_data.query,
            k=form_data.k if form_data.k else app.state.TOP_K,
            embedding_function=app.state.sentence_transformer_ef,
        )
        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
            query_embeddings = generate_ollama_embeddings(
                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
            )

            return query_embeddings_doc(
                collection_name=form_data.collection_name,
                query_embeddings=query_embeddings,
                k=form_data.k if form_data.k else app.state.TOP_K,
            )
        else:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                k=form_data.k if form_data.k else app.state.TOP_K,
                embedding_function=app.state.sentence_transformer_ef,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
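
From a client's point of view the /query/doc route is unchanged; only the embedding path behind it switches on RAG_EMBEDDING_ENGINE. A hedged request-side sketch; the URL, token, and collection name are placeholders:

import requests

payload = {
    "collection_name": "doc-7f3c2a",  # hypothetical collection created at upload time
    "query": "What does the warranty cover?",
    "k": 4,
}
resp = requests.post(
    "http://localhost:8080/rag/api/v1/query/doc",  # assumed base URL and mount path
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # placeholder credentials
)
print(resp.json())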
@@ -277,12 +301,30 @@ def query_collection_handler(
    form_data: QueryCollectionsForm,
    user=Depends(get_current_user),
):
    return query_collection(
        collection_names=form_data.collection_names,
        query=form_data.query,
        k=form_data.k if form_data.k else app.state.TOP_K,
        embedding_function=app.state.sentence_transformer_ef,
    )
    try:
        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
            query_embeddings = generate_ollama_embeddings(
                {"model": app.state.RAG_EMBEDDING_MODEL, "prompt": form_data.query}
            )

            return query_embeddings_collection(
                collection_names=form_data.collection_names,
                query_embeddings=query_embeddings,
                k=form_data.k if form_data.k else app.state.TOP_K,
            )
        else:
            return query_collection(
                collection_names=form_data.collection_names,
                query=form_data.query,
                k=form_data.k if form_data.k else app.state.TOP_K,
                embedding_function=app.state.sentence_transformer_ef,
            )
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )
@app.post("/web")
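
Taken together, the ollama branch in both query handlers reduces to: embed the query once through Ollama, then search with the precomputed vector. A sketch of that flow using the functions from this commit, mirroring the plain-dict call the handlers make; the model and collection names are made up, and this assumes it runs inside the backend where these modules are importable:

from apps.ollama.main import generate_ollama_embeddings
from apps.rag.utils import query_embeddings_collection

query_embeddings = generate_ollama_embeddings(
    {"model": "nomic-embed-text:latest", "prompt": "What is the refund policy?"}
)
results = query_embeddings_collection(
    collection_names=["docs-handbook", "docs-faq"],  # hypothetical collections
    query_embeddings=query_embeddings,
    k=4,
)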
@@ -317,6 +359,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
        chunk_overlap=app.state.CHUNK_OVERLAP,
        add_start_index=True,
    )

    docs = text_splitter.split_documents(data)

    if len(docs) > 0:
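
For context, the splitter configured above is LangChain's RecursiveCharacterTextSplitter. A tiny standalone sketch with made-up chunk sizes; the real values come from app.state:

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document

text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1500,       # illustrative only
    chunk_overlap=100,
    add_start_index=True,  # each chunk's metadata records its offset in the source
)
docs = text_splitter.split_documents(
    [Document(page_content="long document text ... " * 200, metadata={"source": "example.txt"})]
)
print(len(docs), docs[0].metadata)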
@@ -337,7 +380,9 @@ def store_text_in_vector_db(
    return store_docs_in_vector_db(docs, collection_name, overwrite)


def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
async def store_docs_in_vector_db(
    docs, collection_name, overwrite: bool = False
) -> bool:
    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]
@@ -349,20 +394,36 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                    log.info(f"deleting existing collection {collection_name}")
                    CHROMA_CLIENT.delete_collection(name=collection_name)

        collection = CHROMA_CLIENT.create_collection(
            name=collection_name,
            embedding_function=app.state.sentence_transformer_ef,
        )
        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
            collection = CHROMA_CLIENT.create_collection(name=collection_name)

        for batch in create_batches(
            api=CHROMA_CLIENT,
            ids=[str(uuid.uuid1()) for _ in texts],
            metadatas=metadatas,
            documents=texts,
        ):
            collection.add(*batch)
            for batch in create_batches(
                api=CHROMA_CLIENT,
                ids=[str(uuid.uuid1()) for _ in texts],
                metadatas=metadatas,
                embeddings=[
                    generate_ollama_embeddings(
                        {"model": RAG_EMBEDDING_MODEL, "prompt": text}
                    )
                    for text in texts
                ],
            ):
                collection.add(*batch)
        else:
            collection = CHROMA_CLIENT.create_collection(
                name=collection_name,
                embedding_function=app.state.sentence_transformer_ef,
            )

        return True
            for batch in create_batches(
                api=CHROMA_CLIENT,
                ids=[str(uuid.uuid1()) for _ in texts],
                metadatas=metadatas,
                documents=texts,
            ):
                collection.add(*batch)

            return True
    except Exception as e:
        log.exception(e)
        if e.__class__.__name__ == "UniqueConstraintError":
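
The key difference between the two branches above is who computes the vectors: the default branch hands Chroma a sentence-transformers embedding_function and passes raw documents, while the ollama branch stores embeddings precomputed via Ollama. A minimal, self-contained illustration of the precomputed path; the names and vectors are made up:

import chromadb

client = chromadb.Client()  # in-memory client, just for the sketch
collection = client.create_collection(name="demo-docs")  # no embedding_function: vectors supplied by the caller

texts = ["alpha paragraph", "beta paragraph"]
collection.add(
    ids=["0", "1"],
    documents=texts,
    metadatas=[{"source": "demo"}, {"source": "demo"}],
    embeddings=[[0.1, 0.2, 0.3], [0.9, 0.8, 0.7]],  # placeholders; the real code gets one vector per chunk from Ollama
)
print(collection.count())  # 2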

apps/rag/utils.py

@@ -2,6 +2,9 @@ import os
import re
import logging
from typing import List
import requests
from huggingface_hub import snapshot_download
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
@@ -26,6 +29,21 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
        raise e


def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
    try:
        # if you use docker use the model from the environment variable
        collection = CHROMA_CLIENT.get_collection(
            name=collection_name,
        )
        result = collection.query(
            query_embeddings=[query_embeddings],
            n_results=k,
        )
        return result
    except Exception as e:
        raise e


def merge_and_sort_query_results(query_results, k):
    # Initialize lists to store combined data
    combined_ids = []
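
A note on the [query_embeddings] wrapping in query_embeddings_doc: Chroma's query() is batched, so it expects a list of query vectors and returns one nested result list per vector. A small self-contained sketch (collection name and vectors are made up):

import chromadb

client = chromadb.Client()
collection = client.get_or_create_collection(name="demo-docs")
collection.add(ids=["0", "1"], documents=["a", "b"], embeddings=[[0.0, 1.0], [1.0, 0.0]])

# one query vector in, one nested result list out
result = collection.query(query_embeddings=[[0.1, 0.9]], n_results=1)
print(result["documents"])  # [['a']]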
@@ -96,6 +114,24 @@ def query_collection(
    return merge_and_sort_query_results(results, k)


def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
    results = []

    for collection_name in collection_names:
        try:
            collection = CHROMA_CLIENT.get_collection(name=collection_name)
            result = collection.query(
                query_embeddings=[query_embeddings],
                n_results=k,
            )
            results.append(result)
        except:
            pass

    return merge_and_sort_query_results(results, k)


def rag_template(template: str, context: str, query: str):
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)