from os.path import curdir, join
from typing import Callable, Optional

import torch
from lorem.text import TextLorem

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    """Next-character prediction dataset over generated lorem-ipsum text.

    On construction, 512 lorem-ipsum words are generated and joined with
    single spaces. Every character is stored as ``ord(c) % 256`` in a 1-D
    ``torch.long`` tensor (``self.dataset``). Item ``i`` is the pair
    ``(x, y)`` where ``x`` is the window ``dataset[i : i + context_length]``
    and ``y`` is the single value that immediately follows that window.

    Args:
        root: Forwarded to the base ``Dataset`` constructor.
        transform: Optional callable applied to ``x`` (not to ``y``) in
            ``__getitem__``. Forwarded to the base ``Dataset`` constructor.
        context_length: Number of values in each input window. Defaults to
            128, the value that was previously hard-coded.

    Raises:
        ValueError: If ``context_length`` is not a positive integer.
    """

    def __init__(
        self,
        root: str = "data",
        transform: Optional[Callable] = None,
        context_length: int = 128,
    ):
        if context_length <= 0:
            raise ValueError(
                f"context_length must be positive, got {context_length}"
            )

        # The base class is presumably what stores ``self.transform``, which
        # __getitem__ reads -- TODO confirm against Dataset.
        super().__init__(root, transform)

        # Generate the text: 512 words separated by single spaces.
        # NOTE(review): ``_word`` is an underscore-prefixed (non-public)
        # TextLorem method; kept as-is to preserve the generated text format.
        _lorem = TextLorem()
        _text = ' '.join(_lorem._word() for _ in range(512))

        # NOTE(review): this overwrites ``_root`` with "./data" regardless of
        # the ``root`` argument passed above. Behaviour kept unchanged because
        # the base class is not visible here -- confirm whether ``root``
        # should be honoured instead.
        path = join(curdir, "data")
        self._root = path

        # Map each character to its code point modulo 256. For ASCII text
        # this equals the UTF-8 byte value; it is not a general UTF-8 encode.
        self.dataset = torch.tensor(
            [ord(c) % 256 for c in _text], dtype=torch.long
        )
        self.context_length = context_length

    def __len__(self):
        """Return the number of (window, next value) pairs available.

        A window needs ``context_length`` values plus one following target,
        so the count is ``len(dataset) - context_length``. The result is
        clamped at zero because ``len()`` must never be negative.
        """
        return max(0, self.dataset.size(0) - self.context_length)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(x, y)`` training pair.

        Args:
            idx: Integer position. Negative values count from the end, as
                with built-in sequences.

        Returns:
            Tuple ``(x, y)``: ``x`` is the ``context_length``-long window
            starting at ``idx`` (passed through ``self.transform`` when one is
            set) and ``y`` is the 0-d tensor holding the value right after
            the window.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = len(self)
        if idx < 0:
            # Rebind rather than ``+=`` so a tensor index passed by the
            # caller is never mutated in place.
            idx = idx + length
        if not 0 <= idx < length:
            raise IndexError(
                f"index out of range for dataset of length {length}"
            )

        stop = idx + self.context_length
        x = self.dataset[idx:stop]
        y = self.dataset[stop]
        if self.transform is not None:
            x = self.transform(x)
        return x, y