diff --git a/CNN-model/dataset_loaders/EnWik9.py b/CNN-model/dataset_loaders/EnWik9.py
index 32698a8..bef57a1 100644
--- a/CNN-model/dataset_loaders/EnWik9.py
+++ b/CNN-model/dataset_loaders/EnWik9.py
@@ -1,25 +1,43 @@
 from datasets import load_dataset
+from torch.utils.data import Dataset
+import torch
 from os.path import curdir, join
-from .Dataset import Dataset
-from torch.utils.data import TensorDataset
 from typing import Callable
 
 
 class EnWik9DataSet(Dataset):
-    def __init__(self, root: str = "data", transform: Callable = None):
-        super().__init__(root, transform)
+    def __init__(self, root: str = "data", transform: Callable | None = None):
+        super().__init__()
+        self.transform = transform
+        # Load the HuggingFace dataset (raw wiki text as strings)
         path = join(curdir, root)
-        self._root = path
-
         data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
+
+        # Extract raw text
         text = data["text"]
-        self.dataset = TensorDataset(text)
+
+        # Convert text (Python string) → bytes → tensor of ints 0–255
+        # (every UTF-8 byte is 0–255; errors="replace" guards against unencodable lone surrogates)
+        byte_data = "".join(text).encode("utf-8", errors="replace")
+        self.data = torch.tensor(list(byte_data), dtype=torch.long)
+
+        # Model uses a fixed 128-byte context
+        self.context_length = 128
 
     def __len__(self):
-        return len(self.dataset)
+        # number of sliding windows
+        return len(self.data) - self.context_length
 
     def __getitem__(self, idx):
-        if self.transform is not None:
-            return self.transform(self.dataset[idx])
-        return self.dataset[idx]
+        # context window
+        x = self.data[idx : idx + self.context_length]
+
+        # next-byte target
+        y = self.data[idx + self.context_length]
+
+        if self.transform:
+            x = self.transform(x)
+
+        return x, y
+
 
diff --git a/CNN-model/main_cnn.py b/CNN-model/main_cnn.py
index 89bb70e..6b277dd 100644
--- a/CNN-model/main_cnn.py
+++ b/CNN-model/main_cnn.py
@@ -20,6 +20,7 @@ if __name__ == "__main__":
     parser.add_argument("--model-path", type=str, required=False)
     args = parser.parse_args()
 
+    print("Loading the dataset...")
     if args.method == "train":
         dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
     elif args.method == "optuna":
@@ -28,9 +29,11 @@ if __name__ == "__main__":
         raise ValueError(f"Unknown method: {args.method}")
 
     dataset_length = len(dataset)
+    print(f"Dataset size = {dataset_length}")
+
     training_size = ceil(0.8 * dataset_length)
-    print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
+    print(f"Training set size = {training_size}, Validation set size = {dataset_length - training_size}")
 
     train_set, validate_set = torch.utils.data.random_split(dataset, [training_size, dataset_length - training_size])
 
 
@@ -40,6 +43,7 @@ if __name__ == "__main__":
 
     model = None
     if args.model_path is not None:
+        print("Loading the model...")
         model = torch.load(args.model_path)
 
     trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
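
A quick way to sanity-check the new sliding-window contract without downloading enwik9 (a minimal sketch, not part of the patch; ToyByteDataset and the 10-element tensor are hypothetical stand-ins for EnWik9DataSet and its self.data):

    import torch
    from torch.utils.data import Dataset, DataLoader

    class ToyByteDataset(Dataset):
        """Mirrors the patched indexing on a tiny tensor."""
        def __init__(self, data: torch.Tensor, context_length: int = 4):
            self.data = data
            self.context_length = context_length

        def __len__(self):
            # one (window, next-byte) pair per valid start index
            return len(self.data) - self.context_length

        def __getitem__(self, idx):
            x = self.data[idx : idx + self.context_length]   # context window
            y = self.data[idx + self.context_length]         # next-byte target
            return x, y

    data = torch.arange(10, dtype=torch.long)  # pretend bytes 0..9
    ds = ToyByteDataset(data)
    assert len(ds) == 6                        # 10 - 4 valid start positions
    x, y = ds[0]
    assert x.tolist() == [0, 1, 2, 3] and y.item() == 4

    # default collate batches cleanly: xb is (B, 4), yb is (B,)
    xb, yb = next(iter(DataLoader(ds, batch_size=2)))

One thing to watch in the patched version: transform is applied only to x, so with transform=lambda x: x.to(DEVICE) the target y still leaves the dataset on the CPU and the training loop has to move it to DEVICE itself.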