forked from open-webui/open-webui
		
	docker improvements & changed universal device type env for different models used
This commit is contained in:
		
							parent
							
								
									132d741c55
								
							
						
					
					
						commit
						1f6739337b
					
				
					 4 changed files with 36 additions and 19 deletions
				
			
		|  | @ -21,7 +21,11 @@ from utils.utils import ( | |||
| ) | ||||
| from utils.misc import calculate_sha256 | ||||
| 
 | ||||
| from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR | ||||
| from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR, DEVICE_TYPE | ||||
| 
 | ||||
| if DEVICE_TYPE != "cuda": | ||||
|     whisper_device_type = "cpu" | ||||
| 
 | ||||
| 
 | ||||
| app = FastAPI() | ||||
| app.add_middleware( | ||||
|  | @ -56,7 +60,7 @@ def transcribe( | |||
| 
 | ||||
|         model = WhisperModel( | ||||
|             WHISPER_MODEL, | ||||
|             device="auto", | ||||
|             device=whisper_device_type, | ||||
|             compute_type="int8", | ||||
|             download_root=WHISPER_MODEL_DIR, | ||||
|         ) | ||||
|  |  | |||
|  | @ -57,7 +57,7 @@ from config import ( | |||
|     UPLOAD_DIR, | ||||
|     DOCS_DIR, | ||||
|     RAG_EMBEDDING_MODEL, | ||||
|     RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|     DEVICE_TYPE, | ||||
|     CHROMA_CLIENT, | ||||
|     CHUNK_SIZE, | ||||
|     CHUNK_OVERLAP, | ||||
|  | @ -87,7 +87,7 @@ app.state.TOP_K = 4 | |||
| app.state.sentence_transformer_ef = ( | ||||
|     embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|         model_name=app.state.RAG_EMBEDDING_MODEL, | ||||
|         device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|         device=DEVICE_TYPE, | ||||
|     ) | ||||
| ) | ||||
| 
 | ||||
|  | @ -175,7 +175,7 @@ async def update_embedding_model( | |||
|     app.state.sentence_transformer_ef = ( | ||||
|         embedding_functions.SentenceTransformerEmbeddingFunction( | ||||
|             model_name=app.state.RAG_EMBEDDING_MODEL, | ||||
|             device=RAG_EMBEDDING_MODEL_DEVICE_TYPE, | ||||
|             device=DEVICE_TYPE, | ||||
|         ) | ||||
|     ) | ||||
| 
 | ||||
|  |  | |||
|  | @ -330,8 +330,8 @@ CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" | |||
| # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2) | ||||
| RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2") | ||||
| # device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance | ||||
| RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get( | ||||
|     "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu" | ||||
| DEVICE_TYPE = os.environ.get( | ||||
|     "DEVICE_TYPE", "cpu" | ||||
| ) | ||||
| CHROMA_CLIENT = chromadb.PersistentClient( | ||||
|     path=CHROMA_DATA_PATH, | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jannik Streidl
						Jannik Streidl