from abc import abstractmethod, ABC from os.path import join, curdir from typing import Callable import torch from torch import Tensor from torch.utils.data import Dataset as TorchDataset from tqdm import tqdm """ Author: Tibo De Peuter """ class Dataset(TorchDataset, ABC): """Abstract base class for datasets.""" @abstractmethod def __init__(self, name: str, root: str | None, split: str = 'train', transform: Callable = None, size: int = -1 ): """ :param root: Path to the dataset root directory :param split: The dataset split, e.g. 'train', 'validation', 'test' :param size: Override the maximum size of the dataset, useful for debugging """ if root is None: root = join(curdir, 'data') self._root = join(root, name) self.split = split self.transform = transform self.size = size self.data = None self.chunk_offsets: list[int] = [] self.bytes: bytes = bytes() self.tensor: Tensor = torch.tensor([]) @property def root(self): return self._root def __len__(self): return len(self.dataset) def process_data(self): if self.size == -1: # Just use the whole dataset self.bytes = ''.join(tqdm(self.data, desc="Encoding data")).encode('utf-8', errors='replace') else: # Use only partition, calculate offsets self.chunk_offsets = self.get_offsets() self.bytes = ''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data")).encode('utf-8', errors='replace') self.tensor = torch.tensor(list(self.bytes), dtype=torch.long) def get_offsets(self): """ Calculate for each chunk how many bytes came before it """ offsets = [0] while len(offsets) <= len(self.data) and (self.size == -1 or offsets[-1] < self.size): idx = len(offsets) - 1 offsets.append(offsets[idx] + len(self.data[idx])) print(offsets) return offsets def get_chunked_item(self, idx: int, offsets: list[int], context_length: int): item = '' # Determine first chunk in which item is located chunk_idx = 0 while idx >= offsets[chunk_idx]: chunk_idx += 1 chunk_idx -= 1 # Extract item from chunks chunk = str(self.data[chunk_idx]) chunk_start = offsets[chunk_idx] chunk_item_start = idx - chunk_start item_len_remaining = context_length + 1 assert len(item) + item_len_remaining == context_length + 1 while chunk_item_start + item_len_remaining > len(chunk): adding_now_len = len(chunk) - chunk_item_start item += chunk[chunk_item_start:] chunk_idx += 1 chunk = str(self.data[chunk_idx]) chunk_item_start = 0 item_len_remaining -= adding_now_len assert len(item) + item_len_remaining == context_length + 1 item += chunk[chunk_item_start: chunk_item_start + item_len_remaining] assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}" # Transform to tensor data = ''.join(item).encode('utf-8', errors='replace') t = torch.tensor(list(data), dtype=torch.long) x, y = t[:-1], t[-1] if self.transform: x = self.transform(x) return x, y