11 lines
357 B
Python
11 lines
357 B
Python
import torch
|
|
from torch.utils.data import TensorDataset
|
|
|
|
|
|
def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
    """Build (context window, next byte) training pairs from a byte string.

    Sample ``i`` pairs the window ``data[i : i + context_length]`` with the
    byte that immediately follows it, ``data[i + context_length]``.

    Args:
        data: Raw bytes to slice into windows.
        context_length: Number of consecutive bytes in each context window.
            Negative values are not supported.

    Returns:
        A ``TensorDataset`` holding two ``torch.uint8`` tensors:
        ``x`` of shape ``(len(data) - context_length, context_length)`` and
        ``y`` of shape ``(len(data) - context_length,)``. When ``data`` is not
        longer than ``context_length`` the dataset contains zero samples.
    """
    tokens = torch.tensor(list(data), dtype=torch.uint8)
    sample_count = tokens.shape[0] - context_length

    if sample_count < 0:
        # Fewer bytes than one full window: ``unfold`` would raise here, so
        # return an empty dataset instead, matching what the function already
        # yields when len(data) == context_length.
        empty_x = torch.empty((0, context_length), dtype=torch.uint8)
        empty_y = torch.empty((0,), dtype=torch.uint8)
        return TensorDataset(empty_x, empty_y)

    # ``unfold`` produces every sliding window, including the final one that
    # has no following byte to serve as a target; the slice drops that window.
    x = tokens.unfold(0, context_length, 1)[:sample_count]
    # Targets are the input shifted left by one full window.
    y = tokens[context_length:]
    return TensorDataset(x, y)
|
|
|