fix: accelerator module is not available on HPC pytorch version
parent f026be49aa
commit ea9cf12db0
1 changed file with 7 additions and 1 deletion
@@ -8,7 +8,13 @@ from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
 from trainers import OptunaTrainer, Trainer, FullTrainer
 
 BATCH_SIZE = 64
-DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
+
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+elif torch.backends.mps.is_available():
+    DEVICE = "mps"
+else:
+    DEVICE = "cpu"
 
 # hyper parameters
 context_length = 128
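Background: the torch.accelerator module only exists in recent PyTorch releases, so the one-liner removed above raises AttributeError on the older PyTorch build installed on the HPC cluster. A minimal sketch (not part of this commit, using only standard torch APIs) of a variant that keeps the accelerator API where it exists and falls back to the same manual detection otherwise:

    import torch

    # Prefer the accelerator API on PyTorch builds that expose it;
    # hasattr() guards against older versions without the module.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        DEVICE = torch.accelerator.current_accelerator().type
    elif torch.cuda.is_available():
        DEVICE = "cuda"
    elif torch.backends.mps.is_available():
        DEVICE = "mps"
    else:
        DEVICE = "cpu"

The commit instead drops the accelerator API entirely, which is the simpler choice when the code only ever targets CUDA, MPS, or CPU.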