from typing import Callable

import torch
from lorem.text import TextLorem

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    """Character-level next-byte-prediction dataset over generated lorem-ipsum text.

    Each item is a ``(context, target)`` pair: a tensor of ``context_length``
    byte values and the single byte that follows it.
    """

    def __init__(
        self,
        root: str | None = None,
        transform: Callable | None = None,
        size: int = 512,
        context_length: int = 128,
    ):
        """
        Args:
            root: Storage root, passed through to the base ``Dataset``.
            transform: Optional callable applied to each context tensor in
                ``__getitem__``; ``None`` disables it.
            size: Number of lorem-ipsum words to generate.
            context_length: Length of each input sequence (previously a
                hard-coded 128; the default preserves old behavior).
        """
        super().__init__('lorem_ipsum', root, transform)

        # Generate `size` random lorem words joined by single spaces.
        # NOTE(review): `_word()` is a private TextLorem method — works, but
        # may break on a library upgrade; confirm against the lorem package.
        _lorem = TextLorem()
        _text = ' '.join(_lorem._word() for _ in range(size))

        # Encode as real UTF-8 bytes, one long value (0-255) per position.
        # Lorem text is ASCII, so this is byte-identical to the previous
        # `ord(c) % 256` loop while the "UTF-8" claim is now actually true.
        self.dataset = torch.tensor(list(_text.encode('utf-8')), dtype=torch.long)
        self.context_length = context_length

    def __len__(self) -> int:
        # Number of windows that still have a following target byte;
        # clamped at 0 so a too-short dataset can't yield a negative length.
        return max(self.dataset.size(0) - self.context_length, 0)

    def __getitem__(self, idx):
        # Input: `context_length` consecutive byte values starting at `idx`.
        x = self.dataset[idx: idx + self.context_length]
        # Target: the byte immediately after the input window.
        y = self.dataset[idx + self.context_length]
        if self.transform is not None:
            x = self.transform(x)
        return x, y