fix: Manual selection of device
This commit is contained in:
parent
20cbd61a82
commit
962230b6e1
2 changed files with 19 additions and 7 deletions
8
main.py
8
main.py
|
|
@ -1,17 +1,13 @@
|
||||||
import torch
|
|
||||||
|
|
||||||
from src.args import parse_arguments
|
from src.args import parse_arguments
|
||||||
from src.process import compress
|
from src.process import compress
|
||||||
from src.train import train
|
from src.train import train
|
||||||
|
from src.utils import determine_device
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args, print_help = parse_arguments()
|
args, print_help = parse_arguments()
|
||||||
|
|
||||||
if torch.accelerator.is_available():
|
device = determine_device()
|
||||||
device = torch.accelerator.current_accelerator().type
|
|
||||||
else:
|
|
||||||
device = "cpu"
|
|
||||||
print(f"Running on device: {device}...")
|
print(f"Running on device: {device}...")
|
||||||
|
|
||||||
match args.mode:
|
match args.mode:
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset
|
from torch.utils.data import TensorDataset
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
|
|
||||||
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
|
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
|
||||||
|
|
@ -12,10 +12,12 @@ def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
|
||||||
y = data[context_length:]
|
y = data[context_length:]
|
||||||
return TensorDataset(x, y)
|
return TensorDataset(x, y)
|
||||||
|
|
||||||
|
|
||||||
def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
    """Show a histogram of *probabilities* over the integer range [from_to[0], from_to[1])."""
    lo, hi = from_to
    plt.hist(range(lo, hi), weights=probabilities)
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False):
|
def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False):
|
||||||
plt.plot(train_losses, label="Training loss")
|
plt.plot(train_losses, label="Training loss")
|
||||||
plt.plot(validation_losses, label="Validation loss")
|
plt.plot(validation_losses, label="Validation loss")
|
||||||
|
|
@ -33,6 +35,20 @@ def print_losses(train_losses: list[float], validation_losses: list[float], file
|
||||||
plt.savefig(filename)
|
plt.savefig(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def determine_device():
    """Return a torch.device for the best accelerator available.

    Backends are probed in priority order: NVIDIA CUDA (most HPC
    clusters), Apple Silicon MPS (macOS), Intel XPU (oneAPI). Falls
    back to CPU when no accelerator is usable.
    """
    probes = (
        # NVIDIA GPUs (most HPC clusters)
        ("cuda", lambda: torch.cuda.is_available()),
        # Apple Silicon (macOS); guard for builds without the mps backend
        ("mps", lambda: bool(getattr(torch.backends, "mps", None))
                and torch.backends.mps.is_available()),
        # Intel GPUs (oneAPI); guard for builds without torch.xpu
        ("xpu", lambda: hasattr(torch, "xpu") and torch.xpu.is_available()),
    )
    for name, is_up in probes:
        if is_up():
            return torch.device(name)
    return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
def load_data(path: str) -> bytes:
    """Load the entire file at *path* and return its raw bytes."""
    with open(path, "rb") as source:
        contents = source.read()
    return contents
|
||||||
|
|
|
||||||
Reference in a new issue