This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../src/train.py
2025-12-13 17:53:01 +01:00

80 lines
2.6 KiB
Python

from pathlib import Path
import torch
from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called
from src.models import model_called
from src.trainers import OptunaTrainer, Trainer, FullTrainer
def train(
device,
dataset: str,
data_root: str,
n_trials: int | None = None,
size: int | None = None,
context_length: int | None = None,
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
model_out: str | None = None,
results_dir: str = 'results'
):
batch_size = 64
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
if model_name:
print(f"Creating model: {model_name}")
model = model_called[model_name]
else:
print("Loading model from disk")
model = torch.load(model_path, weights_only=False)
dataset_common_args = {
'root': data_root,
'transform': lambda x: x.to(device),
}
if size:
dataset_common_args['size'] = size
if context_length:
dataset_common_args['context_length'] = context_length
print("Loading in the dataset...")
if dataset in dataset_called:
training_set = dataset_called[dataset](split='train', **dataset_common_args)
validate_set = dataset_called[dataset](split='validation', **dataset_common_args)
else:
# TODO Allow to import arbitrary files
raise NotImplementedError(f"Importing external datasets is not implemented yet")
if method == 'fetch':
# TODO More to earlier in chain, because now everything is converted into tensors as well?
exit(0)
print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
print("Training")
best_model = trainer.execute(
model=model,
context_length=context_length,
train_loader=training_loader,
validation_loader=validation_loader,
n_epochs=n_trials,
device=device
)
print("Saving model...")
f = model_out or f"saved_models/{model.__class__.__name__}.pt"
# Make sure path exists
Path(f).parent.mkdir(parents=True, exist_ok=True)
torch.save(best_model, f)
print(f"Saved model to '{f}'")