From 713934edb65991e93766974cc981ee10df26404d Mon Sep 17 00:00:00 2001 From: "Timothy J. Baek" Date: Sat, 20 Apr 2024 15:21:52 -0500 Subject: [PATCH] refac --- backend/apps/audio/main.py | 32 ++++++++++++++++++++++++++++++++ backend/apps/rag/main.py | 28 ++++++++++++++-------------- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/backend/apps/audio/main.py b/backend/apps/audio/main.py index 94c1c359..f7ce6fec 100644 --- a/backend/apps/audio/main.py +++ b/backend/apps/audio/main.py @@ -15,6 +15,8 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse from fastapi.middleware.cors import CORSMiddleware from faster_whisper import WhisperModel +from pydantic import BaseModel + import requests import hashlib @@ -67,6 +69,36 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/") SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) +class OpenAIConfigUpdateForm(BaseModel): + url: str + key: str + + +@app.get("/config") +async def get_openai_config(user=Depends(get_admin_user)): + return { + "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + } + + +@app.post("/config/update") +async def update_openai_config( + form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user) +): + if form_data.key == "": + raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) + + app.state.OPENAI_API_BASE_URL = form_data.url + app.state.OPENAI_API_KEY = form_data.key + + return { + "status": True, + "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, + "OPENAI_API_KEY": app.state.OPENAI_API_KEY, + } + + @app.post("/speech") async def speech(request: Request, user=Depends(get_verified_user)): idx = None diff --git a/backend/apps/rag/main.py b/backend/apps/rag/main.py index 47ffc017..ac8410db 100644 --- a/backend/apps/rag/main.py +++ b/backend/apps/rag/main.py @@ -96,8 +96,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_TEMPLATE = RAG_TEMPLATE -app.state.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL -app.state.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY +app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL +app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.PDF_EXTRACT_IMAGES = False @@ -150,8 +150,8 @@ async def get_embedding_config(user=Depends(get_admin_user)): "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.RAG_EMBEDDING_MODEL, "openai_config": { - "url": app.state.RAG_OPENAI_API_BASE_URL, - "key": app.state.RAG_OPENAI_API_KEY, + "url": app.state.OPENAI_API_BASE_URL, + "key": app.state.OPENAI_API_KEY, }, } @@ -182,8 +182,8 @@ async def update_embedding_config( app.state.sentence_transformer_ef = None if form_data.openai_config != None: - app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url - app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key + app.state.OPENAI_API_BASE_URL = form_data.openai_config.url + app.state.OPENAI_API_KEY = form_data.openai_config.key else: sentence_transformer_ef = ( embedding_functions.SentenceTransformerEmbeddingFunction( @@ -201,8 +201,8 @@ async def update_embedding_config( "embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_model": app.state.RAG_EMBEDDING_MODEL, "openai_config": { - "url": app.state.RAG_OPENAI_API_BASE_URL, - "key": app.state.RAG_OPENAI_API_KEY, + "url": app.state.OPENAI_API_BASE_URL, + "key": app.state.OPENAI_API_KEY, }, } @@ -317,8 +317,8 @@ def query_doc_handler( query_embeddings = generate_openai_embeddings( model=app.state.RAG_EMBEDDING_MODEL, text=form_data.query, - key=app.state.RAG_OPENAI_API_KEY, - url=app.state.RAG_OPENAI_API_BASE_URL, + key=app.state.OPENAI_API_KEY, + url=app.state.OPENAI_API_BASE_URL, ) return query_embeddings_doc( @@ -369,8 +369,8 @@ def query_collection_handler( query_embeddings = generate_openai_embeddings( model=app.state.RAG_EMBEDDING_MODEL, text=form_data.query, - key=app.state.RAG_OPENAI_API_KEY, - url=app.state.RAG_OPENAI_API_BASE_URL, + key=app.state.OPENAI_API_KEY, + url=app.state.OPENAI_API_BASE_URL, ) return query_embeddings_collection( @@ -486,8 +486,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b generate_openai_embeddings( model=app.state.RAG_EMBEDDING_MODEL, text=text, - key=app.state.RAG_OPENAI_API_KEY, - url=app.state.RAG_OPENAI_API_BASE_URL, + key=app.state.OPENAI_API_KEY, + url=app.state.OPENAI_API_BASE_URL, ) for text in texts ]