from typing import Callable, Optional

import torch
from os.path import curdir, join
from lorem.text import TextLorem

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    """Byte-level dataset built from lorem-ipsum text generated at construction.

    ``num_words`` words produced by ``TextLorem`` are joined with single
    spaces, encoded as UTF-8 and stored as a 2-D ``torch.long`` tensor of shape
    ``(num_sequences, sequence_length)``. Each value is a byte in 0-255.
    Trailing bytes that do not fill a whole sequence are dropped.

    NOTE(review): the text is not seeded here, so different instances
    presumably hold different content — confirm if reproducibility matters.
    """

    def __init__(
        self,
        root: str = "data",
        transform: Optional[Callable] = None,
        num_words: int = 512,
        sequence_length: int = 128,
    ):
        """Generate the text and split it into fixed-length byte sequences.

        Args:
            root: Data directory. It is joined onto the current directory,
                and an absolute path is used unchanged (``os.path.join``
                semantics).
            transform: Optional callable applied to each sequence returned by
                ``__getitem__``.
            num_words: Number of lorem-ipsum words to generate.
            sequence_length: Number of bytes per sequence (row).

        Raises:
            ValueError: If ``sequence_length`` is not a positive integer.
        """
        super().__init__(root, transform)
        if sequence_length <= 0:
            raise ValueError(
                f"sequence_length must be positive, got {sequence_length}"
            )

        # Previously this always used "./data", silently ignoring `root`.
        # The default root="data" still resolves to the same path.
        self._root = join(curdir, root)

        # NOTE(review): `_word` is a private TextLorem method; a public API
        # would be more stable across `lorem` package versions.
        lorem = TextLorem()
        text = " ".join(lorem._word() for _ in range(num_words))

        # Encode to real UTF-8 bytes so every value fits a 256-symbol
        # vocabulary; ord() would yield code points above 255 for non-ASCII.
        data = torch.tensor(list(text.encode("utf-8")), dtype=torch.long)

        # Keep only as many bytes as fill whole rows of `sequence_length`.
        sequence_count = data.shape[0] // sequence_length
        data = data[: sequence_count * sequence_length]
        self.dataset = data.view(-1, sequence_length)
        print(self.dataset.shape)

    def __len__(self):
        """Return the number of sequences (rows) in the dataset."""
        return self.dataset.size(0)

    def __getitem__(self, idx):
        """Return row ``idx`` (a 1-D tensor of byte values), transformed if a
        ``transform`` was supplied."""
        sample = self.dataset[idx]
        if self.transform is not None:
            return self.transform(sample)
        return sample