More format fixes

This commit is contained in:
Self Denial 2024-04-04 12:07:42 -06:00
parent bcf79c8366
commit 075fbedb02
2 changed files with 12 additions and 5 deletions

View file

@ -141,17 +141,21 @@ async def update_embedding_model(
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
log.debug(f"form_data.embedding_model: {form_data.embedding_model}") log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
log.info(f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}") log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try: try:
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(app.state.RAG_EMBEDDING_MODEL, True) app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_get_path(
app.state.RAG_EMBEDDING_MODEL, True
)
app.state.sentence_transformer_ef = ( app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL_PATH, model_name=app.state.RAG_EMBEDDING_MODEL_PATH,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
) )
) )
except Exception as e: except Exception as e:
log.exception(f"Problem updating embedding model: {e}") log.exception(f"Problem updating embedding model: {e}")
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@ -159,9 +163,11 @@ async def update_embedding_model(
) )
if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path: if app.state.RAG_EMBEDDING_MODEL_PATH == old_model_path:
status = False status = False
log.debug(f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}") log.debug(
f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}"
)
log.debug(f"old_model_path: {old_model_path}") log.debug(f"old_model_path: {old_model_path}")
log.debug(f"status: {status}") log.debug(f"status: {status}")

View file

@ -191,6 +191,7 @@ def rag_messages(docs, messages, template, k, embedding_function):
return messages return messages
def embedding_model_get_path( def embedding_model_get_path(
embedding_model: str, update_embedding_model: bool = False embedding_model: str, update_embedding_model: bool = False
): ):