Merge branch 'main' into process
commit 4c8d603092
14 changed files with 393 additions and 109 deletions
src/args.py (14 changed lines)

@@ -9,10 +9,14 @@ def parse_arguments():
                         help="Enable debug mode: smaller datasets, more information")
+    parser.add_argument("--verbose", "-v", action="store_true", required=False,
+                        help="Enable verbose mode")
+    parser.add_argument("--results", type=str, required=True, help="path to save graphs to")
+    parser.add_argument("--device", required=False, help="Override the device to use")
 
     dataparser = ArgumentParser(add_help=False)
     dataparser.add_argument("--data-root", type=str, required=False)
     dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)
     dataparser.add_argument("--size", "-s", type=int, required=False,
                             help="Size of the subset of the dataset to use")
 
     modelparser = ArgumentParser(add_help=False)
     modelparser.add_argument("--model", "-m", type=str, required=False,
@@ -21,6 +25,8 @@ def parse_arguments():
                             help="Filepath to the model to load")
     modelparser.add_argument("--model-save-path", type=str, required=False,
                              help="Filepath to the model to save")
+    modelparser.add_argument("--context", type=int, required=False,
+                             help="Context length to use")
 
     fileparser = ArgumentParser(add_help=False)
     fileparser.add_argument("--input-file", "-i", required=False, type=str)
@@ -35,11 +41,11 @@ def parse_arguments():
     train_parser.add_argument("--method",
                               choices=["fetch", "optuna", "full"], required=True,
                               help="Method to use for training")
     train_parser.add_argument("--size", "-s", type=int, required=False,
                               help="Size of the subset of the dataset to use")
 
-    compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])
+    subparsers.add_parser("compress", parents=[modelparser, fileparser],
+                          help="Compress a file")
 
-    decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])
+    subparsers.add_parser("decompress", parents=[modelparser, fileparser],
+                          help="Decompress a file")
 
     return parser.parse_args(), parser.print_help
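
For orientation, a minimal sketch of how the updated parse_arguments() might be consumed by a caller. It is not part of the diff; the import path src.args assumes the repository root is on PYTHONPATH, and the printed values are illustrative.

# Illustrative only: parse_arguments() returns (args, print_help) as shown above.
from src.args import parse_arguments

args, print_help = parse_arguments()

if args.verbose:                               # new --verbose / -v flag
    print(f"results dir: {args.results}")      # new required --results flag
    print(f"device override: {args.device}")   # new optional --device flag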

@@ -23,7 +23,8 @@ class Dataset(TorchDataset, ABC):
                  root: str | None,
                  split: str = 'train',
                  transform: Callable = None,
-                 size: int = -1
+                 size: int = -1,
+                 context_length: int = 1024
                  ):
         """
         :param root: Path to the dataset root directory
@@ -37,8 +38,11 @@ class Dataset(TorchDataset, ABC):
         self.split = split
         self.transform = transform
         self.size = size
+        self.context_length = context_length
         self.data = None
 
+        print(f"Context length: {self.context_length}")
+
         self.chunk_offsets: list[int] = []
         self.bytes: bytes = bytes()
         self.tensor: Tensor = torch.tensor([])
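
The context_length threaded through the base class sets the size of the next-token window the loaders slice out of self.tensor. A minimal sketch, not from the diff, of the (x, y) pairs this produces; it mirrors the __getitem__ added in HumanReferenceGenomeDataset below, with a toy tensor standing in for the loaded corpus.

import torch

# Toy stand-in for self.tensor; real subclasses fill it via process_data().
tensor = torch.arange(10)
context_length = 4

idx = 2
x = tensor[idx: idx + context_length]   # context window -> tensor([2, 3, 4, 5])
y = tensor[idx + context_length]        # next-token target -> tensor(6)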

@@ -15,9 +15,10 @@ class EnWik9DataSet(Dataset):
                  root: str | None = None,
                  split: str = 'train',
                  transform: Callable | None = None,
-                 size: int = -1
+                 size: int = -1,
+                 context_length: int = 1024
                  ):
-        super().__init__('enwik9', root, split, transform, size)
+        super().__init__('enwik9', root, split, transform, size, context_length)
 
         print(f"Loading from HuggingFace")
         ft = Features({'text': Value('string')})
@@ -26,9 +27,6 @@ class EnWik9DataSet(Dataset):
         self.data = text_chunks['text']
-        self.size = size
 
-        # Model uses fixed 128-length context
-        self.context_length = 128
 
         self.process_data()
 
         # Define splits manually, because they do not exist in the dataset

src/dataset_loaders/HumanReferenceGenomeDataset.py (new file, 45 lines)

@@ -0,0 +1,45 @@
+from typing import Callable
+
+from datasets import load_dataset
+
+from .Dataset import Dataset
+
+
+class HumanReferenceGenomeDataset(Dataset):
+    """
+    Hugging Face: https://huggingface.co/datasets/InstaDeepAI/human_reference_genome
+
+    :param split: 'train' | 'validation' | 'test'
+    :param config: '6kbp' | '12kbp' (chunk length in the HF builder config)
+    """
+
+    def __init__(self,
+                 root: str | None = None,
+                 split: str = "train",
+                 transform: Callable = None,
+                 size: int = -1,
+                 context_length: int = 1024,
+                 config: str = "6kbp",
+                 ):
+        super().__init__("human_reference_genome", root, split, transform, size, context_length)
+
+        print(f"Loading from HuggingFace (config: {config}, split: {split})")
+        data = load_dataset("InstaDeepAI/human_reference_genome", config, split=split,
+                            cache_dir=self.root, trust_remote_code=True)
+        self.data = data["sequence"]
+
+        self.process_data()
+
+        print("Done initializing dataset")
+
+    def __len__(self):
+        return self.chunk_offsets[-1] - self.context_length
+
+    def __getitem__(self, idx):
+        x = self.tensor[idx: idx + self.context_length]
+        y = self.tensor[idx + self.context_length]
+
+        if self.transform:
+            x = self.transform(x)
+
+        return x, y
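
A minimal usage sketch for the new dataset class, not part of the diff. It assumes the package imports as src.dataset_loaders (see the registry change below) and that a real download from the Hugging Face Hub is acceptable; the size and context_length values are illustrative.

# Illustrative only; constructing this downloads InstaDeepAI/human_reference_genome.
from src.dataset_loaders import HumanReferenceGenomeDataset

ds = HumanReferenceGenomeDataset(split="train", size=1000, context_length=256, config="6kbp")
x, y = ds[0]   # x: 256-token context window, y: the following token
print(len(ds), x.shape)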

@@ -12,17 +12,16 @@ class LoremIpsumDataset(Dataset):
                  root: str | None = None,
                  split: str = 'train',
                  transform: Callable = None,
-                 size: int = 2**30
+                 size: int = 2**30,
+                 context_length: int = 1024
                  ):
-        super().__init__('lorem_ipsum', root, split, transform, size)
+        super().__init__('lorem_ipsum', root, split, transform, size, context_length)
 
         _lorem = TextLorem()
 
         self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
-        self.size = size
 
-        self.context_length = 128
 
         self.process_data()
 
         split_point = ceil(self.chunk_offsets[-1] * 0.8)

@@ -19,9 +19,10 @@ class OpenGenomeDataset(Dataset):
                  split: str = 'train',
                  transform: Callable = None,
                  size: int = -1,
+                 context_length: int = 1024,
                  stage: str = 'stage2'
                  ):
-        super().__init__('open_genome', root, split, transform, size)
+        super().__init__('open_genome', root, split, transform, size, context_length)
 
         print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
         ft = Features({'text': Value('string')})
@@ -29,9 +30,6 @@ class OpenGenomeDataset(Dataset):
         self.data = data['text']
-        self.size = size
 
-        # Model uses fixed 128-length context
-        self.context_length = 128
 
         self.process_data()
 
         print("Done initializing dataset")

@@ -1,10 +1,12 @@
 from .Dataset import Dataset
 from .EnWik9 import EnWik9DataSet
+from .HumanReferenceGenomeDataset import HumanReferenceGenomeDataset
 from .LoremIpsumDataset import LoremIpsumDataset
 from .OpenGenomeDataset import OpenGenomeDataset
 
 dataset_called: dict[str, type[Dataset]] = {
     'enwik9': EnWik9DataSet,
     'lorem_ipsum': LoremIpsumDataset,
-    'opengenome': OpenGenomeDataset
+    'opengenome': OpenGenomeDataset,
+    'humanreference': HumanReferenceGenomeDataset
 }
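
This registry is what the --dataset choices in src/args.py and the dispatch inside train() both key into. A small sketch of that lookup, not from the diff; it assumes the registry lives in the dataset_loaders package __init__, and the keyword values are examples.

# Mirrors train()'s dataset_called[dataset](split='train', **dataset_common_args)
from src.dataset_loaders import dataset_called

dataset_common_args = {'size': 1000, 'context_length': 256}   # example values
training_set = dataset_called['humanreference'](split='train', **dataset_common_args)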

@@ -14,10 +14,12 @@ def train(
         data_root: str,
         n_trials: int | None = None,
         size: int | None = None,
+        context_length: int | None = None,
         method: str = 'optuna',
         model_name: str | None = None,
         model_path: str | None = None,
-        model_out: str | None = None
+        model_out: str | None = None,
+        results_dir: str = 'results'
 ):
     batch_size = 64
@@ -38,6 +40,9 @@ def train(
     if size:
         dataset_common_args['size'] = size
 
+    if context_length:
+        dataset_common_args['context_length'] = context_length
+
     print("Loading in the dataset...")
     if dataset in dataset_called:
         training_set = dataset_called[dataset](split='train', **dataset_common_args)
@@ -54,7 +59,7 @@ def train(
     training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
     validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
 
-    trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
+    trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
 
     print("Training")
     best_model = trainer.execute(
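
The diff does not show the call site that connects the CLI to train(), so the wiring below is an assumption: a sketch of how the new --context and --results flags would plausibly be forwarded.

# Assumed wiring; the actual call site is outside the hunks shown here.
args, print_help = parse_arguments()
train(
    data_root=args.data_root,
    context_length=args.context,   # ends up in dataset_common_args['context_length']
    results_dir=args.results,      # forwarded to FullTrainer(results_dir=...)
    # ... other arguments (dataset, method, model paths) omitted for brevity
)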

@@ -6,8 +6,10 @@ from .trainer import Trainer
 from ..models import Model
 from ..utils import print_losses
 
 
 class FullTrainer(Trainer):
+    def __init__(self, results_dir: str = 'results'):
+        self.results_dir = results_dir
+
     def execute(
             self,
             model: Model,
@@ -20,7 +22,8 @@ class FullTrainer(Trainer):
             raise ValueError("Model must be provided: run optuna optimizations first")
 
         model.to(device)
-        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
-        print_losses(train_loss, val_loss)
+        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
+                                     device=device)
+        print_losses(train_loss, val_loss, filename=f"{self.results_dir}/{model.__class__.__name__}-losses.png")
 
         return model

@@ -1,3 +1,4 @@
+import csv
 from os import path
 
 import matplotlib.pyplot as plt
@@ -34,6 +35,15 @@ def print_losses(train_losses: list[float], validation_losses: list[float], file
     print(f"Saving losses to {filename}...")
     plt.savefig(filename)
 
+    # Also write to CSV file
+    with open(filename.replace(".png", ".csv"), "w") as f:
+        writer = csv.writer(f)
+        writer.writerow(["epoch", "train_loss", "validation_loss"])
+        for i in range(len(train_losses)):
+            writer.writerow([i, train_losses[i], validation_losses[i]])
+
+    print("Done")
+
 
 def determine_device():
     # NVIDIA GPUs (most HPC clusters)
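
For reference, a small sketch of reading the new loss CSV back, for example to re-plot it; not part of the diff. The path shown is an example of what FullTrainer produces (results_dir/<ModelClass>-losses.png with .png swapped for .csv).

import csv

# Example path; print_losses writes filename.replace(".png", ".csv")
with open("results/Model-losses.csv") as f:
    for row in csv.DictReader(f):   # columns: epoch, train_loss, validation_loss
        print(row["epoch"], row["train_loss"], row["validation_loss"])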