import torch
from torch.utils.data import TensorDataset


def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
    """Build (context window, next byte) pairs from a byte sequence.

    Sample ``i`` of the returned dataset is the pair
    ``(data[i:i + context_length], data[i + context_length])``, so each
    target is the byte that immediately follows its context window.

    Args:
        data: Raw bytes to slice into windows. ``bytearray`` and
            ``memoryview`` are accepted as well; any other value is handed
            to ``torch.tensor(..., dtype=torch.uint8)`` unchanged.
        context_length: Number of bytes in each context window. Expected
            to be non-negative.

    Returns:
        A ``TensorDataset`` holding ``x`` of shape
        ``(sample_count, context_length)`` and ``y`` of shape
        ``(sample_count,)``, both ``torch.uint8``, where
        ``sample_count = max(len(data) - context_length, 0)``. When the
        input is too short to yield a single pair the dataset is empty.
    """
    if isinstance(data, (bytes, bytearray, memoryview)):
        # torch.tensor() refuses ``bytes`` objects outright, so bytes-like
        # input is decoded through the buffer protocol instead. Copying into
        # a bytearray gives frombuffer the writable memory it expects and
        # decouples the tensor from the caller's object.
        buffer = bytearray(data)
        if buffer:
            tokens = torch.frombuffer(buffer, dtype=torch.uint8)
        else:
            # frombuffer rejects zero-length buffers.
            tokens = torch.empty(0, dtype=torch.uint8)
    else:
        # Non-bytes input (e.g. a list of ints) keeps the original path.
        tokens = torch.tensor(data, dtype=torch.uint8)

    sample_count = max(tokens.shape[0] - context_length, 0)
    if sample_count == 0:
        # Not enough bytes for even one (window, target) pair. unfold()
        # would raise when the window is longer than the data, so build the
        # empty tensors directly with the shapes callers expect.
        x = tokens.new_empty((0, context_length))
        y = tokens.new_empty((0,))
    else:
        # unfold yields one window more than there are targets (the final
        # window has no following byte), hence the trim to sample_count.
        x = tokens.unfold(0, context_length, 1)[:sample_count]
        y = tokens[context_length:]
    return TensorDataset(x, y)