"""Helpers for building byte-level context datasets and plotting results."""

import torch
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt


def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
    """Build (context window, next byte) training pairs from raw bytes.

    Every byte is converted to its integer value. Sample ``i`` pairs the
    window ``data[i : i + context_length]`` with the byte that immediately
    follows it, ``data[i + context_length]``.

    Args:
        data: Raw byte sequence to slice into samples.
        context_length: Number of consecutive bytes in each input window.

    Returns:
        A ``TensorDataset`` holding two ``torch.long`` tensors: inputs of
        shape ``(sample_count, context_length)`` and targets of shape
        ``(sample_count,)``, where
        ``sample_count = len(data) - context_length``.

    Note:
        No validation is done here; if ``context_length`` is larger than
        ``len(data)`` the underlying ``unfold`` call fails.
    """
    # Keep the tensor under its own name instead of rebinding the `bytes`
    # parameter to a different type.
    tokens = torch.tensor(list(data), dtype=torch.long)
    sample_count = tokens.shape[0] - context_length

    # unfold() yields one window per start offset, including the final
    # window that has no following byte; the slice drops that one.
    windows = tokens.unfold(0, context_length, 1)[:sample_count]
    targets = tokens[context_length:]
    return TensorDataset(windows, targets)


def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
    """Show a weighted histogram over the integers ``from_to[0]..from_to[1] - 1``.

    Args:
        from_to: ``(start, stop)`` pair passed to ``range``; ``stop`` is
            exclusive.
        probabilities: One weight per integer in that range.

    The number of bins is left at matplotlib's default, and the call
    displays the figure via ``plt.show()``.
    """
    plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
    plt.show()


def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
    """Plot training and validation loss curves and save them to ``losses.png``.

    Args:
        train_losses: Training loss values, plotted against their index.
        validation_losses: Validation loss values, plotted against their
            index.
        show: When true, also display the figure after it has been saved.

    The curves are drawn on matplotlib's current figure, and the image is
    written to ``losses.png`` relative to the working directory.
    """
    plt.plot(train_losses, label="Training loss")
    plt.plot(validation_losses, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss (cross entropy)")
    plt.legend()

    # Save BEFORE showing: on interactive backends the figure is destroyed
    # once the window opened by plt.show() is closed, so a later savefig()
    # would write an empty image.
    plt.savefig("losses.png")
    if show:
        plt.show()


def load_data(path: str) -> bytes:
    """Read the file at ``path`` in binary mode and return its full contents.

    Args:
        path: Location of the file to read.

    Returns:
        The raw bytes of the file.
    """
    with open(path, "rb") as f:
        return f.read()