From 5abe0089cb7bf67b639ee241be3f824d58e86cee Mon Sep 17 00:00:00 2001 From: Jannik Streidl Date: Mon, 18 Mar 2024 17:08:34 +0100 Subject: [PATCH] cuda support --- Dockerfile | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/Dockerfile b/Dockerfile index 1ccfdb62..33656bba 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,7 +37,7 @@ ENV OPENAI_API_KEY="" \ SCARF_NO_ANALYTICS=true \ DO_NOT_TRACK=true -#### Preloaded models ########################################################## +#### Preloaded models ######################################################### ## whisper TTS Settings ## ENV WHISPER_MODEL="base" \ WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models" @@ -48,19 +48,32 @@ ENV WHISPER_MODEL="base" \ # for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB) # IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2" \ - # device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance - RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu" \ RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" \ - SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" + SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" \ + # device type for whisper tts and embbeding models - "cpu" (default) or "mps" (apple silicon) - choosing this right can lead to better performance + # Important: + # If you want to use CUDA you need to install the nvidia-container-toolkit (https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) + # you can set this to "cuda" but its recomended to use --build-arg CUDA_ENABLED=true flag when building the image + RAG_EMBEDDING_MODEL_DEVICE_TYPE="cuda" +# device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance #### Preloaded models ########################################################## WORKDIR /app/backend - # install python dependencies COPY ./backend/requirements.txt ./requirements.txt -RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir \ - && pip3 install -r requirements.txt --no-cache-dir +RUN pip3 install -r requirements.txt --no-cache-dir + +RUN if [ "$RAG_EMBEDDING_MODEL_DEVICE_TYPE" = "cuda" ]; then \ + echo "CUDA enabled" && \ + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 --no-cache-dir; \ + else \ + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \ + python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"; \ + fi + +# preload tts model +RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" # install required packages RUN apt-get update \ @@ -71,10 +84,7 @@ RUN apt-get update \ # cleanup && rm -rf /var/lib/apt/lists/* -# preload embedding model -RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])" -# preload tts model -RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" + # copy embedding weight from build # RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2