fix: accelerator module is not available in the HPC PyTorch version

Robin Meersman 2025-11-28 09:58:55 +01:00
parent f026be49aa
commit ea9cf12db0


@@ -8,7 +8,13 @@ from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
 from trainers import OptunaTrainer, Trainer, FullTrainer
 
 BATCH_SIZE = 64
-DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+elif torch.backends.mps.is_available():
+    DEVICE = "mps"
+else:
+    DEVICE = "cpu"
+
 
 # hyper parameters
 context_length = 128
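
Side note: the new code drops torch.accelerator entirely, even on builds that have it. Below is a minimal sketch of a version-tolerant alternative, assuming the only relevant difference between the HPC build and newer PyTorch releases is whether the torch.accelerator module exists; pick_device is a hypothetical helper, not part of this commit:

    import torch

    def pick_device() -> str:
        # torch.accelerator only exists in newer PyTorch releases,
        # so guard with hasattr before touching it.
        if hasattr(torch, "accelerator") and torch.accelerator.is_available():
            return torch.accelerator.current_accelerator().type  # e.g. "cuda"
        # Fallback for older builds, mirroring the commit's logic.
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    DEVICE = pick_device()

This keeps a single code path that works in both environments instead of hard-coding the check per cluster.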