from abc import abstractmethod, ABC from os.path import join, curdir from typing import Callable from torch.utils.data import Dataset as TorchDataset """ Author: Tibo De Peuter """ class Dataset(TorchDataset, ABC): """Abstract base class for datasets.""" @abstractmethod def __init__(self, name: str, root: str | None, transform: Callable = None): """ :param root: Relative path to the dataset root directory """ if root is None: root = join(curdir, 'data') self._root = join(root, name) self.transform = transform self.dataset = None @property def root(self): return self._root def __len__(self): return len(self.dataset)