diff --git a/CNN-model/main_cnn.py b/CNN-model/main_cnn.py index 6b277dd..530122e 100644 --- a/CNN-model/main_cnn.py +++ b/CNN-model/main_cnn.py @@ -8,7 +8,13 @@ from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset from trainers import OptunaTrainer, Trainer, FullTrainer BATCH_SIZE = 64 -DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu" + +if torch.cuda.is_available(): + DEVICE = "cuda" +elif torch.backends.mps.is_available(): + DEVICE = "mps" +else: + DEVICE = "cpu" # hyper parameters context_length = 128