feat: openai embeddings support

This commit is contained in:
Timothy J. Baek 2024-04-14 19:15:39 -04:00
parent 36ce157907
commit b48e73fa43
2 changed files with 127 additions and 54 deletions

View file

@ -53,6 +53,7 @@ from apps.rag.utils import (
query_collection, query_collection,
query_embeddings_collection, query_embeddings_collection,
get_embedding_model_path, get_embedding_model_path,
generate_openai_embeddings,
) )
from utils.misc import ( from utils.misc import (
@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
app.state.RAG_OPENAI_API_KEY = ""
app.state.PDF_EXTRACT_IMAGES = False app.state.PDF_EXTRACT_IMAGES = False
@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
} }
class OpenAIConfigForm(BaseModel):
url: str
key: str
class EmbeddingModelUpdateForm(BaseModel): class EmbeddingModelUpdateForm(BaseModel):
openai_config: Optional[OpenAIConfigForm] = None
embedding_engine: str embedding_engine: str
embedding_model: str embedding_model: str
@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
async def update_embedding_config( async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
try: try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = None app.state.sentence_transformer_ef = None
if form_data.openai_config != None:
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
else: else:
sentence_transformer_ef = ( sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
@ -183,6 +198,10 @@ async def update_embedding_config(
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY,
},
} }
except Exception as e: except Exception as e:
@ -275,6 +294,14 @@ def query_doc_handler(
): ):
try: try:
if app.state.RAG_EMBEDDING_ENGINE == "":
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings( query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(
@ -284,19 +311,20 @@ def query_doc_handler(
} }
) )
) )
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_doc( return query_embeddings_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
) )
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -317,6 +345,15 @@ def query_collection_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.RAG_EMBEDDING_ENGINE == "":
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE == "ollama":
query_embeddings = generate_ollama_embeddings( query_embeddings = generate_ollama_embeddings(
GenerateEmbeddingsForm( GenerateEmbeddingsForm(
@ -326,19 +363,20 @@ def query_collection_handler(
} }
) )
) )
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
return query_embeddings_collection( return query_embeddings_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query_embeddings=query_embeddings, query_embeddings=query_embeddings,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.TOP_K,
) )
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
raise HTTPException( raise HTTPException(
@ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
log.info(f"deleting existing collection {collection_name}") log.info(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name) CHROMA_CLIENT.delete_collection(name=collection_name)
if app.state.RAG_EMBEDDING_ENGINE == "ollama": if app.state.RAG_EMBEDDING_ENGINE == "":
collection = CHROMA_CLIENT.create_collection(name=collection_name)
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=[
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
],
):
collection.add(*batch)
else:
collection = CHROMA_CLIENT.create_collection( collection = CHROMA_CLIENT.create_collection(
name=collection_name, name=collection_name,
@ -446,6 +467,35 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
): ):
collection.add(*batch) collection.add(*batch)
else:
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
embeddings = [
generate_ollama_embeddings(
GenerateEmbeddingsForm(
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
)
)
for text in texts
]
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
embeddings = [
generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL,
text=text,
key=app.state.RAG_OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL,
)
for text in texts
]
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
):
collection.add(*batch)
return True return True
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)

View file

@ -269,3 +269,26 @@ def get_embedding_model_path(
except Exception as e: except Exception as e:
log.exception(f"Cannot determine embedding model snapshot path: {e}") log.exception(f"Cannot determine embedding model snapshot path: {e}")
return embedding_model return embedding_model
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com"
):
try:
r = requests.post(
f"{url}/v1/embeddings",
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": text, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return data["data"][0]["embedding"]
else:
raise "Something went wrong :/"
except Exception as e:
print(e)
return None