Streamline datasets
This commit is contained in:
parent
849bcd7b77
commit
befb1a96a5
8 changed files with 222 additions and 64 deletions
|
|
@@ -1,32 +1,63 @@
|
|||
from math import ceil
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
from lorem.text import TextLorem
|
||||
from tqdm import tqdm
|
||||
|
||||
from .Dataset import Dataset
|
||||
|
||||
|
||||
class LoremIpsumDataset(Dataset):
|
||||
def __init__(self, root: str | None = None, transform: Callable = None, size: int = 512):
|
||||
super().__init__('lorem_ipsum', root, transform)
|
||||
def __init__(self,
|
||||
root: str | None = None,
|
||||
split: str = 'train',
|
||||
transform: Callable = None,
|
||||
size: int = 2**30
|
||||
):
|
||||
super().__init__('lorem_ipsum', root, split, transform, size)
|
||||
|
||||
# Generate text and convert to bytes
|
||||
_lorem = TextLorem()
|
||||
_text = ' '.join(_lorem._word() for _ in range(size))
|
||||
|
||||
# Convert text to bytes (UTF-8 encoded)
|
||||
self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
|
||||
self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
|
||||
self.size = size
|
||||
|
||||
self.context_length = 128
|
||||
|
||||
self.process_data()
|
||||
|
||||
split_point = ceil(self.chunk_offsets[-1] * 0.8)
|
||||
|
||||
if self.split == 'train':
|
||||
self.start_byte = 0
|
||||
self.end_byte = split_point
|
||||
elif self.split == 'validation':
|
||||
self.start_byte = split_point
|
||||
self.end_byte = self.chunk_offsets[-1]
|
||||
else:
|
||||
raise ValueError("split must be 'train' or 'validation'")
|
||||
|
||||
print("Done initializing dataset")
|
||||
|
||||
def __len__(self):
|
||||
# Number of possible sequences of length context_length
|
||||
return self.dataset.size(0) - self.context_length
|
||||
return self.end_byte - self.start_byte - self.context_length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
x = self.dataset[idx: idx + self.context_length]
|
||||
y = self.dataset[idx + self.context_length]
|
||||
# Get sequence of characters
|
||||
# x_str = self.text[idx: idx + self.context_length]
|
||||
# y_char = self.text[idx + self.context_length]
|
||||
#
|
||||
# # Convert to tensors
|
||||
# x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
|
||||
# y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
|
||||
#
|
||||
# if self.transform is not None:
|
||||
# x = self.transform(x)
|
||||
#
|
||||
# return x, y
|
||||
x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
|
||||
y = self.tensor[self.start_byte + idx + self.context_length]
|
||||
|
||||
if self.transform is not None:
|
||||
if self.transform:
|
||||
x = self.transform(x)
|
||||
|
||||
return x, y
|
||||
|
|
|
|||
Reference in a new issue