This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../CNN-model/dataset_loaders/EnWik9.py

44 lines
1.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Callable
import torch
from datasets import load_dataset
from .Dataset import Dataset
class EnWik9DataSet(Dataset):
    """
    Byte-level next-byte-prediction dataset over enwik9.

    Hugging Face: https://huggingface.co/datasets/haukur/enwik9

    The ``text`` column of the ``train`` split is concatenated, UTF-8
    encoded, and stored in ``self.data`` as a 1-D ``torch.long`` tensor of
    byte values (0-255). Item ``i`` is the sliding window
    ``self.data[i : i + context_length]`` paired with the single byte that
    immediately follows it.

    Args:
        root: Passed to the base ``Dataset``; ``self.root`` is then used as
            the Hugging Face ``cache_dir``.
        transform: Optional callable applied to each context window ``x``
            (the target ``y`` is returned untransformed).
        context_length: Number of bytes per context window. Defaults to 128,
            the value this class previously hard-coded.

    Raises:
        ValueError: if ``context_length`` is not positive.
    """

    def __init__(
        self,
        root: str | None = None,
        transform: Callable | None = None,
        context_length: int = 128,
    ):
        super().__init__('enwik9', root, transform)
        if context_length <= 0:
            raise ValueError(
                f"context_length must be positive, got {context_length}"
            )
        # HuggingFace dataset: one Python string per row in the "text" column.
        data = load_dataset("haukur/enwik9", cache_dir=self.root, split="train")
        # NOTE(review): rows are joined with no separator. If the HF loader
        # stripped line breaks from each row, this drops every newline from
        # the original enwik9 byte stream -- confirm against the dataset.
        # UTF-8 always yields byte values 0-255; errors="replace" only
        # matters for lone surrogates, which UTF-8 cannot encode.
        byte_data = "".join(data["text"]).encode("utf-8", errors="replace")
        self.data = self._bytes_to_tensor(byte_data)
        # Length of each context window fed to the model.
        self.context_length = context_length

    @staticmethod
    def _bytes_to_tensor(byte_data: bytes) -> torch.Tensor:
        """Return ``byte_data`` as a 1-D ``torch.long`` tensor of values 0-255."""
        # list(byte_data) would create one Python int reference per byte
        # (~1e9 for enwik9); frombuffer reinterprets the buffer as uint8 in C.
        # A bytearray gives frombuffer a writable buffer (it warns on
        # read-only bytes), and frombuffer rejects empty buffers, hence the
        # guard.
        if not byte_data:
            return torch.empty(0, dtype=torch.long)
        return torch.frombuffer(bytearray(byte_data), dtype=torch.uint8).to(torch.long)

    def __len__(self):
        """Return the number of (context, target) pairs; 0 if data is too short."""
        # Each sample needs context_length bytes plus one target byte.
        return max(0, len(self.data) - self.context_length)

    def __getitem__(self, idx):
        """
        Return the ``(x, y)`` pair for sliding window ``idx``.

        ``x`` is ``self.data[idx : idx + context_length]`` (passed through
        ``self.transform`` when one is set) and ``y`` is the byte right after
        that window. Negative indices count from the end, as for sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        # Context window.
        x = self.data[pos: pos + self.context_length]
        # Next-byte target.
        y = self.data[pos + self.context_length]
        if self.transform is not None:
            x = self.transform(x)
        return x, y