This repository was archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing or creating new issues, pull requests, or comments.
2025ML-project-neural_compr.../dataset_loaders/EnWik9.py
Robin Meersman 73d1742cbd code cleanup
2025-11-30 19:21:29 +01:00

43 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datasets import load_dataset
from torch.utils.data import Dataset
import torch
from os.path import curdir, join
from typing import Callable
class EnWik9DataSet(Dataset):
def __init__(self, root: str = "data", transform: Callable | None = None):
super().__init__()
self.transform = transform
# HuggingFace dataset: string text
path = join(curdir, root)
data = load_dataset("haukur/enwik9", cache_dir=path, split="train")
# Extract raw text
text = data["text"]
# Convert text (Python string) → bytes → tensor of ints 0255
# UTF-8 but non-ASCII bytes may exceed 255, so enforce modulo or ignore errors
byte_data = "".join(text).encode("utf-8", errors="replace")
self.data = torch.tensor(list(byte_data), dtype=torch.long)
# Model uses fixed 128-length context
self.context_length = 128
def __len__(self):
# number of sliding windows
return len(self.data) - self.context_length
def __getitem__(self, idx):
# context window
x = self.data[idx : idx + self.context_length]
# next byte target
y = self.data[idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y