This commit is contained in:
Timothy J. Baek 2024-04-10 00:46:09 -07:00
parent f4b87ecb23
commit abfcceecef
2 changed files with 28 additions and 23 deletions

View file

@ -142,43 +142,40 @@ class EmbeddingModelUpdateForm(BaseModel):
async def update_embedding_model( async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
status = True
old_model_path = app.state.RAG_EMBEDDING_MODEL_PATH
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
log.debug(f"form_data.embedding_model: {form_data.embedding_model}") log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
log.info( log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
embedding_model_path = None
sentence_transformer_ef = None
try: try:
app.state.RAG_EMBEDDING_MODEL_PATH = get_embedding_model_path( embedding_model_path = get_embedding_model_path(form_data.embedding_model, True)
app.state.RAG_EMBEDDING_MODEL, True if app.state.RAG_EMBEDDING_MODEL_PATH != embedding_model_path:
) sentence_transformer_ef = (
app.state.sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction(
embedding_functions.SentenceTransformerEmbeddingFunction( model_name=embedding_model_path,
model_name=app.state.RAG_EMBEDDING_MODEL_PATH, device=DEVICE_TYPE,
device=DEVICE_TYPE, )
) )
)
except Exception as e: except Exception as e:
log.exception(f"Problem updating embedding model: {e}") log.exception(f"Problem updating embedding model: {e}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e, detail=ERROR_MESSAGES.DEFAULT(e),
) )
if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path: if sentence_transformer_ef:
status = False app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_path
app.state.sentence_transformer_ef = sentence_transformer_ef
log.debug( log.debug(
f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}" f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}"
) )
log.debug(f"old_model_path: {old_model_path}")
log.debug(f"status: {status}")
return { return {
"status": status, "status": sentence_transformer_ef != None,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, "embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
} }

View file

@ -35,6 +35,9 @@
k: 4 k: 4
}; };
let embeddingModelConfig = {
embedding_model: ''
};
let embeddingModel = ''; let embeddingModel = '';
const scanHandler = async () => { const scanHandler = async () => {
@ -61,7 +64,13 @@
console.log('Update embedding model attempt:', embeddingModel); console.log('Update embedding model attempt:', embeddingModel);
updateEmbeddingModelLoading = true; updateEmbeddingModelLoading = true;
const res = await updateEmbeddingModel(localStorage.token, { embedding_model: embeddingModel }); const res = await updateEmbeddingModel(localStorage.token, {
embedding_model: embeddingModel
}).catch((error) => {
toast.error(error);
embeddingModel = embeddingModelConfig.embedding_model;
return null;
});
updateEmbeddingModelLoading = false; updateEmbeddingModelLoading = false;
if (res) { if (res) {
@ -99,8 +108,7 @@
chunkOverlap = res.chunk.chunk_overlap; chunkOverlap = res.chunk.chunk_overlap;
} }
const embeddingModelConfig = await getEmbeddingModel(localStorage.token); embeddingModelConfig = await getEmbeddingModel(localStorage.token);
embeddingModel = embeddingModelConfig.embedding_model; embeddingModel = embeddingModelConfig.embedding_model;
querySettings = await getQuerySettings(localStorage.token); querySettings = await getQuerySettings(localStorage.token);