fix: accelerator module is not available on HPC pytorch version
This commit is contained in:
parent
f026be49aa
commit
ea9cf12db0
1 changed files with 7 additions and 1 deletions
|
|
@ -8,7 +8,13 @@ from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
|
|||
from trainers import OptunaTrainer, Trainer, FullTrainer
|
||||
|
||||
BATCH_SIZE = 64
|
||||
DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
DEVICE = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
DEVICE = "mps"
|
||||
else:
|
||||
DEVICE = "cpu"
|
||||
|
||||
# hyper parameters
|
||||
context_length = 128
|
||||
|
|
|
|||
Reference in a new issue