from os.path import curdir, join
from typing import Callable, Optional

from datasets import load_dataset
from torch.utils.data import TensorDataset

from .Dataset import Dataset


class EnWik9DataSet(Dataset):
    """Dataset wrapper around the ``text`` column of ``haukur/enwik9``.

    The ``train`` split is loaded through ``datasets.load_dataset`` and its
    ``text`` column is kept as the backing store. Items are returned as raw
    entries of that column, optionally passed through ``transform``.
    """

    def __init__(self, root: str = "data", transform: Optional[Callable] = None):
        """Load the dataset, caching it under ``<curdir>/<root>``.

        Args:
            root: Directory, relative to ``os.path.curdir``, that is handed
                to ``load_dataset`` as ``cache_dir``.
            transform: Optional callable applied to each item in
                ``__getitem__``. It is forwarded to the base class, which
                is presumably what exposes it as ``self.transform``
                (NOTE(review): base class is not visible here - confirm).
        """
        super().__init__(root, transform)
        path = join(curdir, root)
        self._root = path
        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
        # Keep the text column itself. The previous code wrapped it in
        # TensorDataset, which calls ``.size(0)`` on every argument and so
        # raises AttributeError for a sequence of strings.
        self.dataset = data["text"]

    def __len__(self):
        """Return the number of entries in the text column."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return entry ``idx``, passed through ``transform`` when one is set."""
        sample = self.dataset[idx]
        if self.transform is not None:
            return self.transform(sample)
        return sample