from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.dataset_loaders import dataset_called
from src.models import model_called
from src.trainers import OptunaTrainer, Trainer, FullTrainer


def train(
    device,
    dataset: str,
    data_root: str,
    n_trials: int | None = None,
    size: int | None = None,
    context_length: int | None = None,
    method: str = 'optuna',
    model_name: str | None = None,
    model_path: str | None = None,
    model_out: str | None = None,
    results_dir: str = 'results'
):
    """Train a registered model on a registered dataset and save the best checkpoint."""
    batch_size = 64

    assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"

    # Either build a fresh model from the registry or resume from a saved checkpoint.
    if model_name:
        print(f"Creating model: {model_name}")
        model = model_called[model_name]
    else:
        print("Loading model from disk")
        model = torch.load(model_path, weights_only=False)

    # Arguments shared by the train and validation splits.
    dataset_common_args = {
        'root': data_root,
        'transform': lambda x: x.to(device),
    }
    if size:
        dataset_common_args['size'] = size
    if context_length:
        dataset_common_args['context_length'] = context_length

    print("Loading in the dataset...")
    if dataset in dataset_called:
        training_set = dataset_called[dataset](split='train', **dataset_common_args)
        validate_set = dataset_called[dataset](split='validation', **dataset_common_args)
    else:
        # TODO Allow importing arbitrary files
        raise NotImplementedError("Importing external datasets is not implemented yet")

    if method == 'fetch':
        # TODO Move to earlier in the chain, because now everything is converted into tensors as well?
        exit(0)

    print(f"Training set size = {len(training_set)}, Validation set size = {len(validate_set)}")
    training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)

    # Either run an Optuna hyperparameter search or a plain full training run.
    trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)

    print("Training")
    best_model = trainer.execute(
        model=model,
        context_length=context_length,
        train_loader=training_loader,
        validation_loader=validation_loader,
        n_epochs=n_trials,
        device=device
    )

    print("Saving model...")
    f = model_out or f"saved_models/{model.__class__.__name__}.pt"
    # Make sure the output directory exists
    Path(f).parent.mkdir(parents=True, exist_ok=True)
    torch.save(best_model, f)
    print(f"Saved model to '{f}'")
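

# Example invocation (a minimal sketch, not part of the original script): the keys
# "my_dataset" and "my_model" below are hypothetical and assume matching entries exist
# in the dataset_called and model_called registries; adjust them to whatever keys the
# project actually registers.
if __name__ == "__main__":
    train(
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        dataset="my_dataset",               # hypothetical key in dataset_called
        data_root="data",                   # root directory handed to the dataset loader
        n_trials=20,                        # Optuna trials (also forwarded as n_epochs above)
        context_length=128,
        method="optuna",                    # "optuna" -> OptunaTrainer, anything else -> FullTrainer
        model_name="my_model",              # hypothetical key in model_called
        model_out="saved_models/my_model.pt",
    )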