forked from open-webui/open-webui
feat: openai embeddings support
This commit is contained in:
parent
36ce157907
commit
b48e73fa43
2 changed files with 127 additions and 54 deletions
|
@ -53,6 +53,7 @@ from apps.rag.utils import (
|
|||
query_collection,
|
||||
query_embeddings_collection,
|
||||
get_embedding_model_path,
|
||||
generate_openai_embeddings,
|
||||
)
|
||||
|
||||
from utils.misc import (
|
||||
|
@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
|||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||
|
||||
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
|
||||
app.state.RAG_OPENAI_API_KEY = ""
|
||||
|
||||
app.state.PDF_EXTRACT_IMAGES = False
|
||||
|
||||
|
@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
|||
"status": True,
|
||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"openai_config": {
|
||||
"url": app.state.RAG_OPENAI_API_BASE_URL,
|
||||
"key": app.state.RAG_OPENAI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAIConfigForm(BaseModel):
|
||||
url: str
|
||||
key: str
|
||||
|
||||
|
||||
class EmbeddingModelUpdateForm(BaseModel):
|
||||
openai_config: Optional[OpenAIConfigForm] = None
|
||||
embedding_engine: str
|
||||
embedding_model: str
|
||||
|
||||
|
@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
|
|||
async def update_embedding_config(
|
||||
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
|
||||
log.info(
|
||||
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||
)
|
||||
|
||||
try:
|
||||
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = None
|
||||
|
||||
if form_data.openai_config != None:
|
||||
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
|
||||
else:
|
||||
sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
|
@ -183,6 +198,10 @@ async def update_embedding_config(
|
|||
"status": True,
|
||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"openai_config": {
|
||||
"url": app.state.RAG_OPENAI_API_BASE_URL,
|
||||
"key": app.state.RAG_OPENAI_API_KEY,
|
||||
},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
@ -275,28 +294,37 @@ def query_doc_handler(
|
|||
):
|
||||
|
||||
try:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return query_embeddings_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
else:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
else:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=form_data.query,
|
||||
key=app.state.RAG_OPENAI_API_KEY,
|
||||
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return query_embeddings_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
|
@ -317,28 +345,38 @@ def query_collection_handler(
|
|||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return query_embeddings_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
else:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
else:
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=form_data.query,
|
||||
key=app.state.RAG_OPENAI_API_KEY,
|
||||
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return query_embeddings_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
|
@ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
log.info(f"deleting existing collection {collection_name}")
|
||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
ids=[str(uuid.uuid1()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
embeddings=[
|
||||
generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{"model": RAG_EMBEDDING_MODEL, "prompt": text}
|
||||
)
|
||||
)
|
||||
for text in texts
|
||||
],
|
||||
):
|
||||
collection.add(*batch)
|
||||
else:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name,
|
||||
|
@ -446,7 +467,36 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
):
|
||||
collection.add(*batch)
|
||||
|
||||
return True
|
||||
else:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
embeddings = [
|
||||
generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
||||
)
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
embeddings = [
|
||||
generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=text,
|
||||
key=app.state.RAG_OPENAI_API_KEY,
|
||||
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
ids=[str(uuid.uuid1()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings,
|
||||
):
|
||||
collection.add(*batch)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
if e.__class__.__name__ == "UniqueConstraintError":
|
||||
|
|
|
@ -269,3 +269,26 @@ def get_embedding_model_path(
|
|||
except Exception as e:
|
||||
log.exception(f"Cannot determine embedding model snapshot path: {e}")
|
||||
return embedding_model
|
||||
|
||||
|
||||
def generate_openai_embeddings(
|
||||
model: str, text: str, key: str, url: str = "https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
r = requests.post(
|
||||
f"{url}/v1/embeddings",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {key}",
|
||||
},
|
||||
json={"input": text, "model": model},
|
||||
)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "data" in data:
|
||||
return data["data"][0]["embedding"]
|
||||
else:
|
||||
raise "Something went wrong :/"
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
|
Loading…
Reference in a new issue