forked from open-webui/open-webui
docker improvements & changed universal device type env for different models used
This commit is contained in:
parent
132d741c55
commit
1f6739337b
4 changed files with 36 additions and 19 deletions
|
@ -21,7 +21,11 @@ from utils.utils import (
|
|||
)
|
||||
from utils.misc import calculate_sha256
|
||||
|
||||
from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
|
||||
from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, DEVICE_TYPE
|
||||
|
||||
if DEVICE_TYPE != "cuda":
|
||||
whisper_device_type = "cpu"
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
|
@ -56,7 +60,7 @@ def transcribe(
|
|||
|
||||
model = WhisperModel(
|
||||
WHISPER_MODEL,
|
||||
device="auto",
|
||||
device=whisper_device_type,
|
||||
compute_type="int8",
|
||||
download_root=WHISPER_MODEL_DIR,
|
||||
)
|
||||
|
|
|
@ -57,7 +57,7 @@ from config import (
|
|||
UPLOAD_DIR,
|
||||
DOCS_DIR,
|
||||
RAG_EMBEDDING_MODEL,
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
DEVICE_TYPE,
|
||||
CHROMA_CLIENT,
|
||||
CHUNK_SIZE,
|
||||
CHUNK_OVERLAP,
|
||||
|
@ -87,7 +87,7 @@ app.state.TOP_K = 4
|
|||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -175,7 +175,7 @@ async def update_embedding_model(
|
|||
app.state.sentence_transformer_ef = (
|
||||
embedding_functions.SentenceTransformerEmbeddingFunction(
|
||||
model_name=app.state.RAG_EMBEDDING_MODEL,
|
||||
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
|
||||
device=DEVICE_TYPE,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -330,8 +330,8 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
|
|||
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
|
||||
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
|
||||
# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
|
||||
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
|
||||
"RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
|
||||
DEVICE_TYPE = os.environ.get(
|
||||
"DEVICE_TYPE", "cpu"
|
||||
)
|
||||
CHROMA_CLIENT = chromadb.PersistentClient(
|
||||
path=CHROMA_DATA_PATH,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue