forked from open-webui/open-webui
feat: move to native sentence_transformer
This commit is contained in:
parent
22c50f62cb
commit
f3e5700d49
7 changed files with 153 additions and 268 deletions
|
@ -13,7 +13,6 @@ import os, shutil, logging, re
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from chromadb.utils import embedding_functions
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
|
@ -38,6 +37,7 @@ import mimetypes
|
|||
import uuid
|
||||
import json
|
||||
|
||||
import sentence_transformers
|
||||
|
||||
from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
|
||||
|
||||
|
@ -48,11 +48,8 @@ from apps.web.models.documents import (
|
|||
)
|
||||
|
||||
from apps.rag.utils import (
|
||||
query_doc,
|
||||
query_embeddings_doc,
|
||||
query_collection,
|
||||
query_embeddings_collection,
|
||||
get_embedding_model_path,
|
||||
generate_openai_embeddings,
|
||||
)
|
||||
|
||||
|
@ -69,7 +66,7 @@ from config import (
|
|||
DOCS_DIR,
|
||||
RAG_EMBEDDING_ENGINE,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
RAG_OPENAI_API_BASE_URL,
|
||||
RAG_OPENAI_API_KEY,
|
||||
DEVICE_TYPE,
|
||||
|
@ -101,15 +98,12 @@ app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
|
|||
|
||||
app.state.PDF_EXTRACT_IMAGES = False
|
||||
|
||||
|
||||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=get_embedding_model_path(
|
||||
app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
|
||||
),
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
origins = ["*"]
|
||||
|
@ -185,13 +179,10 @@ async def update_embedding_config(
|
|||
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
|
||||
app.state.OPENAI_API_KEY = form_data.openai_config.key
|
||||
else:
|
||||
sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=get_embedding_model_path(
|
||||
form_data.embedding_model, True
|
||||
),
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
sentence_transformer_ef = sentence_transformers.SentenceTransformer(
|
||||
app.state.RAG_EMBEDDING_MODEL,
|
||||
device=DEVICE_TYPE,
|
||||
trust_remote_code=RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
|
||||
)
|
||||
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
|
||||
app.state.sentence_transformer_ef = sentence_transformer_ef
|
||||
|
@ -294,38 +285,34 @@ def query_doc_handler(
|
|||
form_data: QueryDocForm,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
return query_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
query_embeddings = app.state.sentence_transformer_ef.encode(
|
||||
form_data.query
|
||||
).tolist()
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=form_data.query,
|
||||
key=app.state.OPENAI_API_KEY,
|
||||
url=app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
else:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
)
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=form_data.query,
|
||||
key=app.state.OPENAI_API_KEY,
|
||||
url=app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return query_embeddings_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
return query_embeddings_doc(
|
||||
collection_name=form_data.collection_name,
|
||||
query=form_data.query,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
@ -348,36 +335,31 @@ def query_collection_handler(
|
|||
):
|
||||
try:
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
return query_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query=form_data.query,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
else:
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
query_embeddings = app.state.sentence_transformer_ef.encode(
|
||||
form_data.query
|
||||
).tolist()
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
query_embeddings = generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{
|
||||
"model": app.state.RAG_EMBEDDING_MODEL,
|
||||
"prompt": form_data.query,
|
||||
}
|
||||
)
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=form_data.query,
|
||||
key=app.state.OPENAI_API_KEY,
|
||||
url=app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return query_embeddings_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
query_embeddings = generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=form_data.query,
|
||||
key=app.state.OPENAI_API_KEY,
|
||||
url=app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
|
||||
return query_embeddings_collection(
|
||||
collection_names=form_data.collection_names,
|
||||
query_embeddings=query_embeddings,
|
||||
k=form_data.k if form_data.k else app.state.TOP_K,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
@ -445,6 +427,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
log.info(f"store_docs_in_vector_db {docs} {collection_name}")
|
||||
|
||||
texts = [doc.page_content for doc in docs]
|
||||
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
||||
|
||||
metadatas = [doc.metadata for doc in docs]
|
||||
|
||||
try:
|
||||
|
@ -454,52 +438,38 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
|
|||
log.info(f"deleting existing collection {collection_name}")
|
||||
CHROMA_CLIENT.delete_collection(name=collection_name)
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "":
|
||||
|
||||
collection = CHROMA_CLIENT.create_collection(
|
||||
name=collection_name,
|
||||
embedding_function=app.state.sentence_transformer_ef,
|
||||
)
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
ids=[str(uuid.uuid1()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
documents=texts,
|
||||
):
|
||||
collection.add(*batch)
|
||||
|
||||
else:
|
||||
collection = CHROMA_CLIENT.create_collection(name=collection_name)
|
||||
|
||||
if app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
embeddings = [
|
||||
generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
||||
)
|
||||
embeddings = app.state.sentence_transformer_ef.encode(texts).tolist()
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "ollama":
|
||||
embeddings = [
|
||||
generate_ollama_embeddings(
|
||||
GenerateEmbeddingsForm(
|
||||
**{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
embeddings = [
|
||||
generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=text,
|
||||
key=app.state.OPENAI_API_KEY,
|
||||
url=app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
elif app.state.RAG_EMBEDDING_ENGINE == "openai":
|
||||
embeddings = [
|
||||
generate_openai_embeddings(
|
||||
model=app.state.RAG_EMBEDDING_MODEL,
|
||||
text=text,
|
||||
key=app.state.OPENAI_API_KEY,
|
||||
url=app.state.OPENAI_API_BASE_URL,
|
||||
)
|
||||
for text in texts
|
||||
]
|
||||
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
ids=[str(uuid.uuid1()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings,
|
||||
documents=texts,
|
||||
):
|
||||
collection.add(*batch)
|
||||
for batch in create_batches(
|
||||
api=CHROMA_CLIENT,
|
||||
ids=[str(uuid.uuid1()) for _ in texts],
|
||||
metadatas=metadatas,
|
||||
embeddings=embeddings,
|
||||
documents=texts,
|
||||
):
|
||||
collection.add(*batch)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue