forked from open-webui/open-webui
feat: openai embeddings support
This commit is contained in:
parent
36ce157907
commit
b48e73fa43
2 changed files with 127 additions and 54 deletions
|
@ -53,6 +53,7 @@ from apps.rag.utils import (
|
||||||
query_collection,
|
query_collection,
|
||||||
query_embeddings_collection,
|
query_embeddings_collection,
|
||||||
get_embedding_model_path,
|
get_embedding_model_path,
|
||||||
|
generate_openai_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
from utils.misc import (
|
from utils.misc import (
|
||||||
|
@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
|
||||||
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
|
||||||
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
app.state.RAG_TEMPLATE = RAG_TEMPLATE
|
||||||
|
|
||||||
|
app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
|
||||||
|
app.state.RAG_OPENAI_API_KEY = ""
|
||||||
|
|
||||||
app.state.PDF_EXTRACT_IMAGES = False
|
app.state.PDF_EXTRACT_IMAGES = False
|
||||||
|
|
||||||
|
@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
|
||||||
"status": True,
|
"status": True,
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
"openai_config": {
|
||||||
|
"url": app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
"key": app.state.RAG_OPENAI_API_KEY,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIConfigForm(BaseModel):
|
||||||
|
url: str
|
||||||
|
key: str
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingModelUpdateForm(BaseModel):
|
class EmbeddingModelUpdateForm(BaseModel):
|
||||||
|
openai_config: Optional[OpenAIConfigForm] = None
|
||||||
embedding_engine: str
|
embedding_engine: str
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
|
|
||||||
|
@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
|
||||||
async def update_embedding_config(
|
async def update_embedding_config(
|
||||||
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
log.info(
|
log.info(
|
||||||
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
|
||||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||||
app.state.sentence_transformer_ef = None
|
app.state.sentence_transformer_ef = None
|
||||||
|
|
||||||
|
if form_data.openai_config != None:
|
||||||
|
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||||
|
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
|
||||||
else:
|
else:
|
||||||
sentence_transformer_ef = (
|
sentence_transformer_ef = (
|
||||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||||
|
@ -183,6 +198,10 @@ async def update_embedding_config(
|
||||||
"status": True,
|
"status": True,
|
||||||
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
|
||||||
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
"openai_config": {
|
||||||
|
"url": app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
"key": app.state.RAG_OPENAI_API_KEY,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -275,6 +294,14 @@ def query_doc_handler(
|
||||||
):
|
):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||||
|
return query_doc(
|
||||||
|
collection_name=form_data.collection_name,
|
||||||
|
query=form_data.query,
|
||||||
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
embedding_function=app.state.sentence_transformer_ef,
|
||||||
|
)
|
||||||
|
else:
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||||
query_embeddings = generate_ollama_embeddings(
|
query_embeddings = generate_ollama_embeddings(
|
||||||
GenerateEmbeddingsForm(
|
GenerateEmbeddingsForm(
|
||||||
|
@ -284,19 +311,20 @@ def query_doc_handler(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||||
|
query_embeddings = generate_openai_embeddings(
|
||||||
|
model=app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
text=form_data.query,
|
||||||
|
key=app.state.RAG_OPENAI_API_KEY,
|
||||||
|
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
|
||||||
return query_embeddings_doc(
|
return query_embeddings_doc(
|
||||||
collection_name=form_data.collection_name,
|
collection_name=form_data.collection_name,
|
||||||
query_embeddings=query_embeddings,
|
query_embeddings=query_embeddings,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return query_doc(
|
|
||||||
collection_name=form_data.collection_name,
|
|
||||||
query=form_data.query,
|
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
||||||
embedding_function=app.state.sentence_transformer_ef,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -317,6 +345,15 @@ def query_collection_handler(
|
||||||
user=Depends(get_current_user),
|
user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||||
|
return query_collection(
|
||||||
|
collection_names=form_data.collection_names,
|
||||||
|
query=form_data.query,
|
||||||
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
|
embedding_function=app.state.sentence_transformer_ef,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||||
query_embeddings = generate_ollama_embeddings(
|
query_embeddings = generate_ollama_embeddings(
|
||||||
GenerateEmbeddingsForm(
|
GenerateEmbeddingsForm(
|
||||||
|
@ -326,19 +363,20 @@ def query_collection_handler(
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||||
|
query_embeddings = generate_openai_embeddings(
|
||||||
|
model=app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
text=form_data.query,
|
||||||
|
key=app.state.RAG_OPENAI_API_KEY,
|
||||||
|
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
|
||||||
return query_embeddings_collection(
|
return query_embeddings_collection(
|
||||||
collection_names=form_data.collection_names,
|
collection_names=form_data.collection_names,
|
||||||
query_embeddings=query_embeddings,
|
query_embeddings=query_embeddings,
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||||
)
|
)
|
||||||
else:
|
|
||||||
return query_collection(
|
|
||||||
collection_names=form_data.collection_names,
|
|
||||||
query=form_data.query,
|
|
||||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
|
||||||
embedding_function=app.state.sentence_transformer_ef,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
|
@ -414,24 +452,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
||||||
log.info(f"deleting existing collection {collection_name}")
|
log.info(f"deleting existing collection {collection_name}")
|
||||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||||
|
|
||||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
|
||||||
|
|
||||||
for batch in create_batches(
|
|
||||||
api=CHROMA_CLIENT,
|
|
||||||
ids=[str(uuid.uuid1()) for _ in texts],
|
|
||||||
metadatas=metadatas,
|
|
||||||
embeddings=[
|
|
||||||
generate_ollama_embeddings(
|
|
||||||
GenerateEmbeddingsForm(
|
|
||||||
**{"model": RAG_EMBEDDING_MODEL, "prompt": text}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for text in texts
|
|
||||||
],
|
|
||||||
):
|
|
||||||
collection.add(*batch)
|
|
||||||
else:
|
|
||||||
|
|
||||||
collection = CHROMA_CLIENT.create_collection(
|
collection = CHROMA_CLIENT.create_collection(
|
||||||
name=collection_name,
|
name=collection_name,
|
||||||
|
@ -446,6 +467,35 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
||||||
):
|
):
|
||||||
collection.add(*batch)
|
collection.add(*batch)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||||
|
embeddings = [
|
||||||
|
generate_ollama_embeddings(
|
||||||
|
GenerateEmbeddingsForm(
|
||||||
|
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for text in texts
|
||||||
|
]
|
||||||
|
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||||
|
embeddings = [
|
||||||
|
generate_openai_embeddings(
|
||||||
|
model=app.state.RAG_EMBEDDING_MODEL,
|
||||||
|
text=text,
|
||||||
|
key=app.state.RAG_OPENAI_API_KEY,
|
||||||
|
url=app.state.RAG_OPENAI_API_BASE_URL,
|
||||||
|
)
|
||||||
|
for text in texts
|
||||||
|
]
|
||||||
|
|
||||||
|
for batch in create_batches(
|
||||||
|
api=CHROMA_CLIENT,
|
||||||
|
ids=[str(uuid.uuid1()) for _ in texts],
|
||||||
|
metadatas=metadatas,
|
||||||
|
embeddings=embeddings,
|
||||||
|
):
|
||||||
|
collection.add(*batch)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
|
@ -269,3 +269,26 @@ def get_embedding_model_path(
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(f"Cannot determine embedding model snapshot path: {e}")
|
log.exception(f"Cannot determine embedding model snapshot path: {e}")
|
||||||
return embedding_model
|
return embedding_model
|
||||||
|
|
||||||
|
|
||||||
|
def generate_openai_embeddings(
|
||||||
|
model: str, text: str, key: str, url: str = "https://api.openai.com"
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/v1/embeddings",
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {key}",
|
||||||
|
},
|
||||||
|
json={"input": text, "model": model},
|
||||||
|
)
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
if "data" in data:
|
||||||
|
return data["data"][0]["embedding"]
|
||||||
|
else:
|
||||||
|
raise "Something went wrong :/"
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return None
|
||||||
|
|
Loading…
Reference in a new issue