forked from open-webui/open-webui

feat: dynamic embedding model load

parent ab104d5905
commit 7c127c35fc

1 changed file with 56 additions and 36 deletions
@@ -35,6 +35,8 @@ from pydantic import BaseModel
 from typing import Optional
 import mimetypes
 import uuid
+import json
+
 
 from apps.web.models.documents import (
     Documents,
@@ -63,24 +65,26 @@ from config import (
 from constants import ERROR_MESSAGES
 
 #
-#if RAG_EMBEDDING_MODEL:
+# if RAG_EMBEDDING_MODEL:
 #     sentence_transformer_ef = SentenceTransformer(
 #         model_name_or_path=RAG_EMBEDDING_MODEL,
 #         cache_folder=RAG_EMBEDDING_MODEL_DIR,
 #         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
 #     )
 
-if RAG_EMBEDDING_MODEL:
-    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=RAG_EMBEDDING_MODEL,
-        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
-    )
 
 app = FastAPI()
 
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.sentence_transformer_ef = (
+    embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=app.state.RAG_EMBEDDING_MODEL,
+        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    )
+)
 
 
 origins = ["*"]
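Note (not part of the diff): the embedding function now lives on app.state instead of a module-level value, which is what makes a runtime swap possible. A minimal sketch of that pattern, using an illustrative model name and a hard-coded device rather than the config values referenced above:

from chromadb.utils import embedding_functions

def build_embedding_function(model_name: str, device: str = "cpu"):
    # Loading a sentence-transformers model may download weights on first use,
    # so rebuilding the function is the expensive part of a model switch.
    return embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=model_name, device=device
    )

# Startup: build once from the configured default (assumed model name here).
sentence_transformer_ef = build_embedding_function("all-MiniLM-L6-v2")

# Later: switching models is just replacing the reference, which is what the
# new /embedding/model/update route does with app.state.sentence_transformer_ef.
sentence_transformer_ef = build_embedding_function("all-mpnet-base-v2")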
@@ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]
 
     try:
-        if RAG_EMBEDDING_MODEL:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name, embedding_function=sentence_transformer_ef
-            )
-        else:
-            # for local development use the default model
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
 
         collection.add(
             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
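Note (not part of the diff): create_collection here and the get_collection calls further down pass the same app.state.sentence_transformer_ef, so stored documents and incoming queries are embedded by the same model. A standalone sketch of that round trip against an in-memory Chroma client (collection name and model are illustrative):

import chromadb
from chromadb.utils import embedding_functions

client = chromadb.Client()  # in-memory client, stands in for CHROMA_CLIENT
ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)

# Store: Chroma does not reapply this function automatically when the
# collection is reopened, so it has to be supplied again later.
collection = client.create_collection(name="docs-demo", embedding_function=ef)
collection.add(documents=["hello world"], ids=["1"])

# Query: reopen with a matching embedding function, as query_doc and
# query_collection now do via app.state.
collection = client.get_collection(name="docs-demo", embedding_function=ef)
print(collection.query(query_texts=["hello"], n_results=1))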
@@ -139,6 +139,38 @@ async def get_status():
         "status": True,
         "chunk_size": app.state.CHUNK_SIZE,
         "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "template": app.state.RAG_TEMPLATE,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+@app.get("/embedding/model")
+async def get_embedding_model(user=Depends(get_admin_user)):
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+class EmbeddingModelUpdateForm(BaseModel):
+    embedding_model: str
+
+
+@app.post("/embedding/model/update")
+async def update_embedding_model(
+    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+):
+    app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+    app.state.sentence_transformer_ef = (
+        embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=app.state.RAG_EMBEDDING_MODEL,
+            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+        )
+    )
+
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }
 
 
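Note (not part of the diff): a hypothetical client-side call sequence for the two new admin routes; the base URL, mount path, token, and model name below are assumptions for illustration, not values from this commit.

import requests

BASE_URL = "http://localhost:8080/rag/api/v1"        # assumed mount point of this app
HEADERS = {"Authorization": "Bearer <admin-token>"}   # both routes require an admin user

# Read the currently active embedding model.
print(requests.get(f"{BASE_URL}/embedding/model", headers=HEADERS).json())

# Switch the model without restarting the server; subsequent uploads and
# queries use the newly built SentenceTransformer embedding function.
response = requests.post(
    f"{BASE_URL}/embedding/model/update",
    headers=HEADERS,
    json={"embedding_model": "all-mpnet-base-v2"},
)
print(response.json())  # e.g. {"status": True, "embedding_model": "all-mpnet-base-v2"}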
@@ -203,17 +235,11 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
-        if RAG_EMBEDDING_MODEL:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
-                embedding_function=sentence_transformer_ef,
-            )
-        else:
-            # for local development use the default model
-            collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
-            )
+        # if you use docker use the model from the environment variable
+        collection = CHROMA_CLIENT.get_collection(
+            name=form_data.collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
     except Exception as e:
@@ -284,17 +310,11 @@ def query_collection(
 
     for collection_name in form_data.collection_names:
         try:
-            if RAG_EMBEDDING_MODEL:
-                # if you use docker use the model from the environment variable
-                collection = CHROMA_CLIENT.get_collection(
-                    name=collection_name,
-                    embedding_function=sentence_transformer_ef,
-                )
-            else:
-                # for local development use the default model
-                collection = CHROMA_CLIENT.get_collection(
-                    name=collection_name,
-                )
+            # if you use docker use the model from the environment variable
+            collection = CHROMA_CLIENT.get_collection(
+                name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
+            )
 
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k