This commit is contained in:
Timothy J. Baek 2024-04-20 15:21:52 -05:00
parent 710850e442
commit 713934edb6
2 changed files with 46 additions and 14 deletions

View file

@ -15,6 +15,8 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
from pydantic import BaseModel
import requests import requests
import hashlib import hashlib
@ -67,6 +69,36 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True) SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class OpenAIConfigUpdateForm(BaseModel):
url: str
key: str
@app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
@app.post("/config/update")
async def update_openai_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
}
@app.post("/speech") @app.post("/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
idx = None idx = None

View file

@ -96,8 +96,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = False app.state.PDF_EXTRACT_IMAGES = False
@ -150,8 +150,8 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": { "openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL, "url": app.state.OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY, "key": app.state.OPENAI_API_KEY,
}, },
} }
@ -182,8 +182,8 @@ async def update_embedding_config(
app.state.sentence_transformer_ef = None app.state.sentence_transformer_ef = None
if form_data.openai_config != None: if form_data.openai_config != None:
app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key app.state.OPENAI_API_KEY = form_data.openai_config.key
else: else:
sentence_transformer_ef = ( sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
@ -201,8 +201,8 @@ async def update_embedding_config(
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"openai_config": { "openai_config": {
"url": app.state.RAG_OPENAI_API_BASE_URL, "url": app.state.OPENAI_API_BASE_URL,
"key": app.state.RAG_OPENAI_API_KEY, "key": app.state.OPENAI_API_KEY,
}, },
} }
@ -317,8 +317,8 @@ def query_doc_handler(
query_embeddings = generate_openai_embeddings( query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL, model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query, text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY, key=app.state.OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL, url=app.state.OPENAI_API_BASE_URL,
) )
return query_embeddings_doc( return query_embeddings_doc(
@ -369,8 +369,8 @@ def query_collection_handler(
query_embeddings = generate_openai_embeddings( query_embeddings = generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL, model=app.state.RAG_EMBEDDING_MODEL,
text=form_data.query, text=form_data.query,
key=app.state.RAG_OPENAI_API_KEY, key=app.state.OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL, url=app.state.OPENAI_API_BASE_URL,
) )
return query_embeddings_collection( return query_embeddings_collection(
@ -486,8 +486,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
generate_openai_embeddings( generate_openai_embeddings(
model=app.state.RAG_EMBEDDING_MODEL, model=app.state.RAG_EMBEDDING_MODEL,
text=text, text=text,
key=app.state.RAG_OPENAI_API_KEY, key=app.state.OPENAI_API_KEY,
url=app.state.RAG_OPENAI_API_BASE_URL, url=app.state.OPENAI_API_BASE_URL,
) )
for text in texts for text in texts
] ]