From ea9cf12db06080638b8bd9bdc66b4a3ab322f20f Mon Sep 17 00:00:00 2001 From: Robin Meersman Date: Fri, 28 Nov 2025 09:58:55 +0100 Subject: [PATCH] fix: accelerator module is not available on HPC pytorch version --- CNN-model/main_cnn.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/CNN-model/main_cnn.py b/CNN-model/main_cnn.py index 6b277dd..530122e 100644 --- a/CNN-model/main_cnn.py +++ b/CNN-model/main_cnn.py @@ -8,7 +8,13 @@ from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset from trainers import OptunaTrainer, Trainer, FullTrainer BATCH_SIZE = 64 -DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" + +if torch.cuda.is_available(): + DEVICE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE = "mps" +else: + DEVICE = "cpu" # hyper parameters context_length = 128