from math import ceil
from typing import Callable

from lorem.text import TextLorem
from tqdm import tqdm

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    """Synthetic character-level dataset built from generated lorem-ipsum text.

    Generates ``size`` random lorem words, joins them with spaces, and serves
    fixed-length character contexts paired with the next character as the
    target. The first 80% of the processed byte range is the 'train' split;
    the remainder is 'validation'.
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = 2**30):
        """Generate the text, process it, and select the split's byte range.

        Args:
            root: Passed through to the ``Dataset`` base class.
            split: Either 'train' (first 80% of bytes) or 'validation'.
            transform: Optional callable applied to each context tensor.
            size: Number of lorem words to generate.

        Raises:
            ValueError: If ``split`` is neither 'train' nor 'validation'.
        """
        super().__init__('lorem_ipsum', root, split, transform, size)

        # TextLorem exposes no public single-word API; _word() is the
        # smallest unit it provides.
        lorem = TextLorem()
        self.data = ' '.join(
            lorem._word() for _ in tqdm(range(size), desc="Generating data")
        )
        self.size = size
        self.context_length = 128

        # Expected to populate self.chunk_offsets and self.tensor
        # (presumably defined on the Dataset base class — TODO confirm).
        self.process_data()

        # 80/20 train/validation split by byte offset into the processed data.
        split_point = ceil(self.chunk_offsets[-1] * 0.8)
        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")
        print("Done initializing dataset")

    def __len__(self) -> int:
        """Number of (context, target) pairs available in this split."""
        # Clamp at 0: a split shorter than context_length + 1 would otherwise
        # yield a negative length, and len() raises ValueError on negatives.
        return max(self.end_byte - self.start_byte - self.context_length, 0)

    def __getitem__(self, idx):
        """Return ``(x, y)``: a context window and its next-element target.

        ``x`` is the slice of ``self.tensor`` of length ``context_length``
        starting at ``start_byte + idx``; ``y`` is the single element that
        immediately follows it.
        """
        start = self.start_byte + idx
        end = start + self.context_length
        x = self.tensor[start:end]
        y = self.tensor[end]
        if self.transform:
            x = self.transform(x)
        return x, y