feat: openai embeddings support

This commit is contained in:
Timothy J. Baek 2024-04-14 19:15:39 -04:00
parent 36ce157907
commit b48e73fa43
2 changed files with 127 additions and 54 deletions

View file

@ -53,6 +53,7 @@ from apps.rag.utils import (
query_collection,
query_embeddings_collection,
get_embedding_model_path,
generate_openai_embeddings,
)
from utils.misc import (
@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
app.state.RAG_OPENAI_API_KEY = ""
app.state.PDF_EXTRACT_IMAGES = False
@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
}
class OpenAIConfigForm(BaseModel):
url: str
key: str
class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str
embedding_model: str
@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = None
if form_data.openai_config != None:
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
else:
sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
@ -183,6 +198,10 @@ async def update_embedding_config(
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
}
except Exception as e:
@ -275,6 +294,14 @@ def query_doc_handler(
):
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
@ -284,19 +311,20 @@ def query_doc_handler(
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_doc(
collection_name=form_data.collection_name,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K,
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
except Exception as e:
log.exception(e)
raise HTTPException(
@ -317,6 +345,15 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
if app.state.RAG_EMBEDDING_ENGINE == "":
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm(
@ -326,19 +363,20 @@ def query_collection_handler(
}
)
)
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_collection(
collection_names=form_data.collection_names,
query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K,
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
except Exception as e:
log.exception(e)
raise HTTPException(
@ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
collection = CHROMA_CLIENT.create_collection(name=collection_name)
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=[
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
],
):
collection.add(*batch)
else:
if app.state.RAG_EMBEDDING_ENGINE == "":
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
@ -446,6 +467,35 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
):
collection.add(*batch)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
embeddings = [
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
]
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
embeddings = [
generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=text,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
for text in texts
]
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
):
collection.add(*batch)
return True
except Exception as e:
log.exception(e)

View file

@ -269,3 +269,26 @@ def get_embedding_model_path(
except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com"
):
try:
r = requests.post(
f"{url}/v1/embeddings",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": text, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return data["data"][0]["embedding"]
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
return None