"""
Author: Tibo De Peuter
"""

from abc import ABC, abstractmethod
from os.path import curdir, join
from typing import Callable

import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm


class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""

    @abstractmethod
    def __init__(self,
                 name: str,
                 root: str | None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1
                 ):
        """
        :param name: Name of the dataset, joined onto root to form the dataset directory
        :param root: Path to the dataset root directory; defaults to './data' when None
        :param split: The dataset split, e.g. 'train', 'validation', 'test'
        :param transform: Optional callable applied to each input tensor
        :param size: Override the maximum size of the dataset (in characters), useful for debugging
        """
        if root is None:
            root = join(curdir, 'data')

        self._root = join(root, name)
        self.split = split
        self.transform = transform
        self.size = size
        # To be populated by subclasses with a sequence of text chunks.
        self.data = None

        self.chunk_offsets: list[int] = []
        self.bytes: bytes = bytes()
        self.tensor: Tensor = torch.tensor([])

    @property
    def root(self):
        return self._root

    def __len__(self):
        return len(self.data)

    def process_data(self):
        """Concatenate the text chunks, encode them as UTF-8 and build the byte tensor."""
        if self.size == -1:
            # Just use the whole dataset
            self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace')
        else:
            # Use only a partition; calculate the chunk offsets first. Note that
            # len(chunk_offsets) is the number of chunks consumed plus one, so the
            # slice below may include one chunk beyond those offsets.
            self.chunk_offsets = self.get_offsets()
            self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace')

        self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
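
        # Worked example (hypothetical data, not from the original file): with
        # self.data == ['ab', 'cd'] and self.size == -1, this produces
        # self.bytes == b'abcd' and self.tensor == tensor([97, 98, 99, 100]).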
    def get_offsets(self) -> list[int]:
        """
        Calculate for each chunk how many characters come before it.
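
        Worked example (hypothetical data): with self.data == ['ab', 'cde'] and
        self.size == 4, the offsets grow [0] -> [0, 2] -> [0, 2, 5], and the loop
        stops once the last offset (5) reaches self.size.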
        """
        offsets = [0]
        while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size):
            idx = len(offsets) - 1
            offsets.append(offsets[idx] + len(self.data[idx]))
        return offsets

    def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
        """
        Extract the item of context_length + 1 characters starting at character
        index idx, stitching it together across chunk boundaries where needed,
        and return it as an (input, target) pair of byte tensors.
        """
        item = ''

        # Determine the first chunk in which the item is located
        chunk_idx = 0
        while idx >= offsets[chunk_idx]:
            chunk_idx += 1
        chunk_idx -= 1

        # Extract the item from the chunks
        chunk = str(self.data[chunk_idx])
        chunk_start = offsets[chunk_idx]

        chunk_item_start = idx - chunk_start
        item_len_remaining = context_length + 1

        assert len(item) + item_len_remaining == context_length + 1

        # Consume chunk tails until the remainder fits inside a single chunk
        while chunk_item_start + item_len_remaining > len(chunk):
            adding_now_len = len(chunk) - chunk_item_start
            item += chunk[chunk_item_start:]

            chunk_idx += 1
            chunk = str(self.data[chunk_idx])

            chunk_item_start = 0
            item_len_remaining -= adding_now_len

            assert len(item) + item_len_remaining == context_length + 1

        item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]

        assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"

        # Transform to tensor. Note: for non-ASCII data, UTF-8 encoding can yield
        # more than one byte per character, making t longer than context_length + 1.
        data = item.encode('utf-8', errors='replace')
        t = torch.tensor(list(data), dtype=torch.long)
        x, y = t[:-1], t[-1]

        if self.transform:
            x = self.transform(x)

        return x, y
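

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative and not part of the
# original interface: ToyDataset, its inline chunks and context_length are
# hypothetical, chosen only to show how a subclass is expected to populate
# self.data, call process_data() and serve items via get_chunked_item().
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    class ToyDataset(Dataset):
        def __init__(self, context_length: int = 8):
            super().__init__(name='toy', root=None, split='train')
            self.context_length = context_length
            self.data = ['hello world, ', 'this is chunk two, ', 'and three.']
            self.process_data()

        def __len__(self):
            # Number of (input, target) pairs that fit in the byte tensor
            return len(self.tensor) - self.context_length

        def __getitem__(self, idx: int):
            return self.get_chunked_item(idx, self.get_offsets(), self.context_length)

    ds = ToyDataset()
    x, y = ds[10]  # an item that crosses the first chunk boundary
    print(x.shape, y)  # torch.Size([8]) plus the target byte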