feat: Context CLI arg

Author: Tibo De Peuter, 2025-12-11 13:58:38 +01:00
Parent: cd74949b74
Commit: a4583d402b
Signed by: tdpeuter (GPG key ID: 38297DE43F75FFE2)
8 changed files with 38 additions and 31 deletions


@@ -23,7 +23,8 @@ class Dataset(TorchDataset, ABC):
         root: str | None,
         split: str = 'train',
         transform: Callable = None,
-        size: int = -1
+        size: int = -1,
+        context_length: int = 1024
     ):
         """
         :param root: Path to the dataset root directory
@@ -37,8 +38,11 @@ class Dataset(TorchDataset, ABC):
         self.split = split
         self.transform = transform
         self.size = size
+        self.context_length = context_length
         self.data = None
+        print(f"Context length: {self.context_length}")
         self.chunk_offsets: list[int] = []
         self.bytes: bytes = bytes()
         self.tensor: Tensor = torch.tensor([])

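The base class now threads context_length into every dataset. The CLI wiring itself is not part of the hunks shown here, so the following is only a sketch of how the new argument is plausibly exposed, assuming argparse; the flag name, entry point, and import path are illustrative, not taken from this commit:

import argparse

# Hypothetical CLI wiring (flag name and entry point assumed, not from this diff).
parser = argparse.ArgumentParser()
parser.add_argument("--context-length", type=int, default=1024,
                    help="tokens per training chunk; matches the constructor default")
args = parser.parse_args()

# All Dataset subclasses now accept the keyword, so the flag passes straight through.
from enwik9 import EnWik9DataSet  # hypothetical import path

dataset = EnWik9DataSet(split="train", context_length=args.context_length)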

@@ -15,9 +15,10 @@ class EnWik9DataSet(Dataset):
         root: str | None = None,
         split: str = 'train',
         transform: Callable | None = None,
-        size: int = -1
+        size: int = -1,
+        context_length: int = 1024
     ):
-        super().__init__('enwik9', root, split, transform, size)
+        super().__init__('enwik9', root, split, transform, size, context_length)
         print(f"Loading from HuggingFace")
         ft = Features({'text': Value('string')})
@@ -26,9 +27,6 @@ class EnWik9DataSet(Dataset):
         self.data = text_chunks['text']
         self.size = size
-        # Model uses fixed 128-length context
-        self.context_length = 128
         self.process_data()
         # Define splits manually, because they do not exist in the dataset

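Behavioural note: EnWik9DataSet used to force a 128-token context regardless of the caller; after this commit it defaults to 1024 unless overridden. A minimal sketch to reproduce the old chunking (import path assumed):

from enwik9 import EnWik9DataSet  # hypothetical import path

# Pass the previously hardcoded value explicitly to keep the old behaviour.
ds = EnWik9DataSet(split='train', context_length=128)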

@@ -18,18 +18,15 @@ class HumanReferenceGenomeDataset(Dataset):
         split: str = "train",
         transform: Callable = None,
         size: int = -1,
+        context_length: int = 1024,
         config: str = "6kbp",
     ):
-        super().__init__("human_reference_genome", root, split, transform, size)
+        super().__init__("human_reference_genome", root, split, transform, size, context_length)
         print(f"Loading from HuggingFace (config: {config}, split: {split})")
-        ds = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
+        data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
                           cache_dir=self.root, trust_remote_code=True)
-        # Your Dataset.process_data() expects a list[str]; use the 'sequence' field
-        self.data = ds["sequence"]
-        self.context_length = 2048
+        self.data = data["sequence"]
         self.process_data()

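The same caveat applies here, with a different constant: the removed line pinned context_length to 2048, so this dataset silently moves to 1024-token chunks unless callers restore the old value. A sketch (import path assumed):

from human_reference_genome import HumanReferenceGenomeDataset  # hypothetical path

# config and split as in the diff; context_length=2048 restores the old chunking.
ds = HumanReferenceGenomeDataset(split="train", config="6kbp", context_length=2048)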

@@ -12,17 +12,16 @@ class LoremIpsumDataset(Dataset):
         root: str | None = None,
         split: str = 'train',
         transform: Callable = None,
-        size: int = 2**30
+        size: int = 2**30,
+        context_length: int = 1024
     ):
-        super().__init__('lorem_ipsum', root, split, transform, size)
+        super().__init__('lorem_ipsum', root, split, transform, size, context_length)
         _lorem = TextLorem()
         self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
         self.size = size
-        self.context_length = 128
         self.process_data()
         split_point = ceil(self.chunk_offsets[-1] * 0.8)

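The 80/20 split kept as trailing context in this hunk rounds the boundary upward; a small worked example of that arithmetic, with hypothetical offsets:

from math import ceil

chunk_offsets = [0, 250, 501, 999]           # hypothetical chunk offsets
split_point = ceil(chunk_offsets[-1] * 0.8)  # ceil(999 * 0.8) = ceil(799.2) = 800
assert split_point == 800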

@@ -19,9 +19,10 @@ class OpenGenomeDataset(Dataset):
         split: str = 'train',
         transform: Callable = None,
         size: int = -1,
+        context_length: int = 1024,
         stage: str = 'stage2'
     ):
-        super().__init__('open_genome', root, split, transform, size)
+        super().__init__('open_genome', root, split, transform, size, context_length)
         print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
         ft = Features({'text': Value('string')})
@@ -29,9 +30,6 @@ class OpenGenomeDataset(Dataset):
         self.data = data['text']
         self.size = size
-        # Model uses fixed 128-length context
-        self.context_length = 128
         self.process_data()
         print("Done initializing dataset")
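
Taken together, the hunks establish a pattern for any future Dataset subclass: accept context_length with the shared 1024 default and forward it to super().__init__() rather than overwriting it after the fact. A minimal sketch (class name and data are illustrative):

class MyCorpusDataset(Dataset):
    def __init__(self, root=None, split='train', transform=None,
                 size=-1, context_length=1024):
        # Forward context_length instead of hardcoding it post-hoc.
        super().__init__('my_corpus', root, split, transform, size, context_length)
        self.data = ['example text']  # load real data here
        self.process_data()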