Merge branch 'main' into process

Tibo De Peuter 2025-12-11 23:16:25 +01:00
commit 4c8d603092
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
14 changed files with 393 additions and 109 deletions

View file

@@ -9,10 +9,14 @@ def parse_arguments():
help="Enable debug mode: smaller datasets, more information")
parser.add_argument("--verbose", "-v", action="store_true", required=False,
help="Enable verbose mode")
parser.add_argument("--results", type=str, required=True, help="path to save graphs to")
parser.add_argument("--device", required=False, help="Override the device to use")
dataparser = ArgumentParser(add_help=False)
dataparser.add_argument("--data-root", type=str, required=False)
dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
dataparser.add_argument("--size", "-s", type=int, required=False,
help="Size of the subset of the dataset to use")
modelparser = ArgumentParser(add_help=False)
modelparser.add_argument("--model", "-m", type=str, required=False,
@@ -21,6 +25,8 @@ def parse_arguments():
help="Filepath to the model to load")
modelparser.add_argument("--model-save-path", type=str, required=False,
help="Filepath to the model to save")
modelparser.add_argument("--context", type=int, required=False,
help="Context length to use")
fileparser = ArgumentParser(add_help=False)
fileparser.add_argument("--input-file", "-i", required=False, type=str)
@@ -35,11 +41,11 @@ def parse_arguments():
train_parser.add_argument("--method",
choices=["fetch", "optuna", "full"], required=True,
help="Method to use for training")
train_parser.add_argument("--size", "-s", type=int, required=False,
help="Size of the subset of the dataset to use")
compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
subparsers.add_parser("compress", parents=[modelparser, fileparser],
help="Compress a file")
decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])
subparsers.add_parser("decompress", parents=[modelparser, fileparser],
help="Decompress a file")
return parser.parse_args(), parser.print_help
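
For orientation, a hypothetical invocation of the updated CLI. The import path (`cli`) and script name are assumptions, and the exact parent-parser wiring of the `train` subcommand is not visible in this hunk; treat this as a sketch rather than the project's actual entry point.

```python
import sys
from cli import parse_arguments  # hypothetical import path for the parser above

# Top-level flags such as --results and --device must precede the subcommand name.
sys.argv = [
    "main.py",
    "--results", "results/run1",
    "train",
    "--method", "optuna",      # one of: fetch, optuna, full
    "--dataset", "enwik9",     # any key of dataset_called
    "--size", "100000",
    "--context", "256",        # new flag introduced by this commit
]

args, print_usage = parse_arguments()
print(args.dataset, args.context, args.results)
```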

View file

@@ -23,7 +23,8 @@ class Dataset(TorchDataset, ABC):
root: str | None,
split: str = 'train',
transform: Callable = None,
size: int = -1
size: int = -1,
context_length: int = 1024
):
"""
:param root: Path to the dataset root directory
@@ -37,8 +38,11 @@ class Dataset(TorchDataset, ABC):
self.split = split
self.transform = transform
self.size = size
self.context_length = context_length
self.data = None
print(f"Context length: {self.context_length}")
self.chunk_offsets: list[int] = []
self.bytes: bytes = bytes()
self.tensor: Tensor = torch.tensor([])
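
process_data() itself is not part of this commit. The following is a minimal sketch of what the new fields (chunk_offsets, bytes, tensor) suggest it does, inferred from how the datasets below slice self.tensor with context_length; everything here is an assumption, not the actual implementation.

```python
import torch

def process_data(self):
    # Assumed behaviour: flatten the per-record text in self.data into one byte
    # buffer, record cumulative chunk boundaries, and expose the bytes as a
    # tensor that __getitem__ can window with self.context_length.
    buffers, offsets, total = [], [0], 0
    for chunk in self.data:              # self.data: iterable of strings
        encoded = chunk.encode("utf-8")
        buffers.append(encoded)
        total += len(encoded)
        offsets.append(total)
    self.bytes = b"".join(buffers)
    self.chunk_offsets = offsets
    # One value per byte; the dtype/token mapping is an assumption.
    self.tensor = torch.tensor(list(self.bytes), dtype=torch.long)
```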

View file

@@ -15,9 +15,10 @@ class EnWik9DataSet(Dataset):
root: str | None = None,
split: str = 'train',
transform: Callable | None = None,
size: int = -1
size: int = -1,
context_length: int = 1024
):
super().__init__('enwik9', root, split, transform, size)
super().__init__('enwik9', root, split, transform, size, context_length)
print(f"Loading from HuggingFace")
ft = Features({'text': Value('string')})
@@ -26,9 +27,6 @@ class EnWik9DataSet(Dataset):
self.data = text_chunks['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
self.process_data()
# Define splits manually, because they do not exist in the dataset

View file

@@ -0,0 +1,45 @@
from typing import Callable
from datasets import load_dataset
from .Dataset import Dataset
class HumanReferenceGenomeDataset(Dataset):
"""
Hugging Face: https://huggingface.co/datasets/InstaDeepAI/human_reference_genome
:param split: 'train' | 'validation' | 'test'
:param config: '6kbp' | '12kbp' (chunk length in the HF builder config)
"""
def __init__(self,
root: str | None = None,
split: str = "train",
transform: Callable = None,
size: int = -1,
context_length: int = 1024,
config: str = "6kbp",
):
super().__init__("human_reference_genome", root, split, transform, size, context_length)
print(f"Loading from HuggingFace (config: {config}, split: {split})")
data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
cache_dir=self.root, trust_remote_code=True)
self.data = data["sequence"]
self.process_data()
print("Done initializing dataset")
def __len__(self):
return self.chunk_offsets[-1] - self.context_length
def __getitem__(self, idx):
x = self.tensor[idx: idx + self.context_length]
y = self.tensor[idx + self.context_length]
if self.transform:
x = self.transform(x)
return x, y
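
A quick way to exercise the new dataset; the root, subset size, and batch size are illustrative, the class is assumed to be imported from the project's dataset package, and the first run downloads the Hugging Face data (which requires trust_remote_code).

```python
from torch.utils.data import DataLoader

# Illustrative smoke test; values are arbitrary choices, not project defaults.
ds = HumanReferenceGenomeDataset(root="data", split="train",
                                 size=10_000, context_length=256)
loader = DataLoader(ds, batch_size=8, shuffle=True)

x, y = next(iter(loader))
print(x.shape, y.shape)  # expected: (8, 256) input windows and (8,) next-symbol targets
```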

View file

@@ -12,17 +12,16 @@ class LoremIpsumDataset(Dataset):
root: str | None = None,
split: str = 'train',
transform: Callable = None,
size: int = 2**30
size: int = 2**30,
context_length: int = 1024
):
super().__init__('lorem_ipsum', root, split, transform, size)
super().__init__('lorem_ipsum', root, split, transform, size, context_length)
_lorem = TextLorem()
self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
self.size = size
self.context_length = 128
self.process_data()
split_point = ceil(self.chunk_offsets[-1] * 0.8)

View file

@@ -19,9 +19,10 @@ class OpenGenomeDataset(Dataset):
split: str = 'train',
transform: Callable = None,
size: int = -1,
context_length: int = 1024,
stage: str = 'stage2'
):
super().__init__('open_genome', root, split, transform, size)
super().__init__('open_genome', root, split, transform, size, context_length)
print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
ft = Features({'text': Value('string')})
@@ -29,9 +30,6 @@ class OpenGenomeDataset(Dataset):
self.data = data['text']
self.size = size
# Model uses fixed 128-length context
self.context_length = 128
self.process_data()
print("Done initializing dataset")

View file

@@ -1,10 +1,12 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet
from .HumanReferenceGenomeDataset import HumanReferenceGenomeDataset
from .LoremIpsumDataset import LoremIpsumDataset
from .OpenGenomeDataset import OpenGenomeDataset
dataset_called: dict[str, type[Dataset]] = {
'enwik9': EnWik9DataSet,
'lorem_ipsum': LoremIpsumDataset,
'opengenome': OpenGenomeDataset
'opengenome': OpenGenomeDataset,
'humanreference': HumanReferenceGenomeDataset
}
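
The registry keeps the CLI decoupled from the concrete dataset classes: --dataset choices come from dataset_called.keys(), and train() instantiates whichever class the name maps to. A rough sketch of that lookup, with illustrative argument values:

```python
# Sketch of a registry lookup; values are illustrative.
name = "humanreference"                      # must be a key of dataset_called
common_args = {"size": 10_000, "context_length": 256}

dataset_cls = dataset_called[name]           # -> HumanReferenceGenomeDataset
train_set = dataset_cls(split="train", **common_args)
```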

View file

@@ -14,10 +14,12 @@ def train(
data_root: str,
n_trials: int | None = None,
size: int | None = None,
context_length: int | None = None,
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
model_out: str | None = None
model_out: str | None = None,
results_dir: str = 'results'
):
batch_size = 64
@@ -38,6 +40,9 @@ def train(
if size:
dataset_common_args['size'] = size
if context_length:
dataset_common_args['context_length'] = context_length
print("Loading in the dataset...")
if dataset in dataset_called:
training_set = dataset_called[dataset](split='train', **dataset_common_args)
@@ -54,7 +59,7 @@ def train(
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
print("Training")
best_model = trainer.execute(

View file

@@ -6,8 +6,10 @@ from .trainer import Trainer
from ..models import Model
from ..utils import print_losses
class FullTrainer(Trainer):
def __init__(self, results_dir: str = 'results'):
self.results_dir = results_dir
def execute(
self,
model: Model,
@@ -20,7 +22,8 @@ class FullTrainer(Trainer):
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
print_losses(train_loss, val_loss)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
device=device)
print_losses(train_loss, val_loss, filename=f"{self.results_dir}/{model.__class__.__name__}-losses.png")
return model
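
With results_dir stored on the trainer, the loss artifacts land in a predictable place; a small sketch of the resulting paths, where the model class name is a placeholder:

```python
from pathlib import Path

trainer = FullTrainer(results_dir="results/run1")

# After trainer.execute(model, ...), print_losses writes, per model class:
#   results/run1/<ModelClass>-losses.png  (the plot)
#   results/run1/<ModelClass>-losses.csv  (added in the utils change below)
print(Path(trainer.results_dir) / "MyModel-losses.png")  # "MyModel" is hypothetical
```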

View file

@@ -1,3 +1,4 @@
import csv
from os import path
import matplotlib.pyplot as plt
@@ -34,6 +35,15 @@ def print_losses(train_losses: list[float], validation_losses: list[float], file
print(f"Saving losses to {filename}...")
plt.savefig(filename)
# Also write to CSV file
with open(filename.replace(".png", ".csv"), "w") as f:
writer = csv.writer(f)
writer.writerow(["epoch", "train_loss", "validation_loss"])
for i in range(len(train_losses)):
writer.writerow([i, train_losses[i], validation_losses[i]])
print("Done")
def determine_device():
# NVIDIA GPUs (most HPC clusters)
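
Because the losses are now also persisted as CSV next to the PNG (see the print_losses change above), the curves can be re-inspected or re-plotted without re-training. A minimal sketch; the file path is illustrative, but the column names match the header written above.

```python
import csv

# Read back the curves written by print_losses; the path is hypothetical.
with open("results/run1/MyModel-losses.csv", newline="") as f:
    rows = list(csv.DictReader(f))

train_curve = [float(r["train_loss"]) for r in rows]
val_curve = [float(r["validation_loss"]) for r in rows]
print(f"{len(rows)} epochs, final validation loss {val_curve[-1]:.4f}")
```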