fix: enwik dataset loading

Robin Meersman 2025-11-28 09:27:37 +01:00
parent fe207962de
commit 0577eee601
2 changed files with 34 additions and 12 deletions

View file

@@ -1,25 +1,43 @@
 from datasets import load_dataset
+from torch.utils.data import Dataset
+import torch
 from os.path import curdir, join
-from .Dataset import Dataset
-from torch.utils.data import TensorDataset
 from typing import Callable

 class EnWik9DataSet(Dataset):
-    def __init__(self, root: str = "data", transform: Callable = None):
-        super().__init__(root, transform)
+    def __init__(self, root: str = "data", transform: Callable | None = None):
+        super().__init__()
+        self.transform = transform
+        # HuggingFace dataset: raw text strings
         path = join(curdir, root)
+        self._root = path
         data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
+        # Extract the raw text
         text = data["text"]
-        self.dataset = TensorDataset(text)
+        # Convert text (Python string) → UTF-8 bytes → tensor of ints 0-255.
+        # UTF-8 encoding always yields bytes in 0-255, so each byte maps onto
+        # a 256-token vocabulary; errors="replace" guards against unencodable
+        # characters such as lone surrogates.
+        byte_data = "".join(text).encode("utf-8", errors="replace")
+        self.data = torch.tensor(list(byte_data), dtype=torch.long)
+        # Model uses a fixed 128-byte context
+        self.context_length = 128

     def __len__(self):
-        return len(self.dataset)
+        # number of sliding windows
+        return len(self.data) - self.context_length

     def __getitem__(self, idx):
-        if self.transform is not None:
-            return self.transform(self.dataset[idx])
-        return self.dataset[idx]
+        # context window
+        x = self.data[idx : idx + self.context_length]
+        # next-byte target
+        y = self.data[idx + self.context_length]
+        if self.transform:
+            x = self.transform(x)
+        return x, y
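Note on the new indexing (not part of the commit): a minimal standalone sketch of the sliding-window scheme above, shrunk to a 4-byte context on a toy string. Only torch is assumed, and all names here are illustrative.

import torch

# Toy stand-in for self.data: UTF-8 bytes of a short string as a long tensor.
data = torch.tensor(list("hello world".encode("utf-8")), dtype=torch.long)
context_length = 4  # the class above uses 128

# One window per starting position, exactly as __len__ computes it.
n_windows = len(data) - context_length
for idx in range(min(n_windows, 3)):
    x = data[idx : idx + context_length]  # context window
    y = data[idx + context_length]        # next-byte target
    print(idx, x.tolist(), "->", int(y))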

View file

@@ -20,6 +20,7 @@ if __name__ == "__main__":
     parser.add_argument("--model-path", type=str, required=False)
     args = parser.parse_args()
+    print("Loading the dataset...")
     if args.method == "train":
         dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
     elif args.method == "optuna":
@@ -28,9 +29,11 @@ if __name__ == "__main__":
         raise ValueError(f"Unknown method: {args.method}")
     dataset_length = len(dataset)
+    print(f"Dataset size = {dataset_length}")
     training_size = ceil(0.8 * dataset_length)
-    print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
+    print(f"Training set size = {training_size}, Validation set size = {dataset_length - training_size}")
     train_set, validate_set = torch.utils.data.random_split(dataset,
                                                             [training_size, dataset_length - training_size])
@@ -40,6 +43,7 @@ if __name__ == "__main__":
     model = None
     if args.model_path is not None:
+        print("Loading the model...")
         model = torch.load(args.model_path)
     trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
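For context (not part of the commit): a sketch of how the 80/20 random_split above feeds into DataLoaders for training. The random stand-in dataset, batch size, and loader settings are assumptions, not code from this repository.

from math import ceil

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for EnWik9DataSet: any map-style dataset of (x, y) pairs works.
xs = torch.randint(0, 256, (1000, 128))
ys = torch.randint(0, 256, (1000,))
dataset = TensorDataset(xs, ys)

# Same 80/20 split as the script above.
dataset_length = len(dataset)
training_size = ceil(0.8 * dataset_length)
train_set, validate_set = torch.utils.data.random_split(
    dataset, [training_size, dataset_length - training_size])

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
validate_loader = DataLoader(validate_set, batch_size=64)
print(f"train batches = {len(train_loader)}, validation batches = {len(validate_loader)}")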