Merge pull request #1419 from lainedfles/embedding-model-fix-and-manual-update

feat: improve embedding model update & resolve network dependency
This commit is contained in:
Timothy Jaeryang Baek 2024-04-10 01:10:07 -07:00 committed by GitHub
commit b9cadff16b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 438 additions and 210 deletions

View file

@ -13,7 +13,6 @@ import os, shutil, logging, re
from pathlib import Path
from typing import List
from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions
from chromadb.utils.batch_utils import create_batches
@ -46,7 +45,7 @@ from apps.web.models.documents import (
DocumentResponse,
)
from apps.rag.utils import query_doc, query_collection
from apps.rag.utils import query_doc, query_collection, get_embedding_model_path
from utils.misc import (
calculate_sha256,
@ -60,6 +59,7 @@ from config import (
UPLOAD_DIR,
DOCS_DIR,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
DEVICE_TYPE,
CHROMA_CLIENT,
CHUNK_SIZE,
@ -78,12 +78,18 @@ app.state.PDF_EXTRACT_IMAGES = False
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.TOP_K = 4
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
model_name=get_embedding_model_path(
app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
),
device=DEVICE_TYPE,
)
)
@ -135,17 +141,33 @@ class EmbeddingModelUpdateForm(BaseModel):
async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=DEVICE_TYPE,
)
log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
try:
sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=get_embedding_model_path(form_data.embedding_model, True),
device=DEVICE_TYPE,
)
)
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = sentence_transformer_ef
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.get("/config")