From 962230b6e1492e7c84ce01660b9a23198a7ed591 Mon Sep 17 00:00:00 2001
From: Tibo De Peuter
Date: Wed, 10 Dec 2025 09:57:24 +0100
Subject: [PATCH] fix: Manual selection of device

---
 main.py            |  8 ++------
 src/utils/utils.py | 18 +++++++++++++++++-
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/main.py b/main.py
index 5c74ba8..d2ac876 100644
--- a/main.py
+++ b/main.py
@@ -1,17 +1,13 @@
-import torch
-
 from src.args import parse_arguments
 from src.process import compress
 from src.train import train
+from src.utils import determine_device
 
 
 def main():
     args, print_help = parse_arguments()
 
-    if torch.accelerator.is_available():
-        device = torch.accelerator.current_accelerator().type
-    else:
-        device = "cpu"
+    device = determine_device()
     print(f"Running on device: {device}...")
 
     match args.mode:
diff --git a/src/utils/utils.py b/src/utils/utils.py
index 24fa61d..4929f20 100644
--- a/src/utils/utils.py
+++ b/src/utils/utils.py
@@ -1,8 +1,8 @@
 from os import path
 
+import matplotlib.pyplot as plt
 import torch
 from torch.utils.data import TensorDataset
-import matplotlib.pyplot as plt
 
 
 def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
@@ -12,10 +12,12 @@ def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
     y = data[context_length:]
     return TensorDataset(x, y)
 
+
 def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
     plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
     plt.show()
 
+
 def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False):
     plt.plot(train_losses, label="Training loss")
     plt.plot(validation_losses, label="Validation loss")
@@ -33,6 +35,20 @@ def print_losses(train_losses: list[float], validation_losses: list[float], file
     plt.savefig(filename)
 
 
+def determine_device():
+    # NVIDIA GPUs (most HPC clusters)
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    # Apple Silicon (macOS)
+    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
+        return torch.device("mps")
+    # Intel GPUs (oneAPI)
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.device("xpu")
+    else:
+        return torch.device("cpu")
+
+
 def load_data(path: str) -> bytes:
     with open(path, "rb") as f:
         return f.read()