80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
from pathlib import Path
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from src.dataset_loaders import dataset_called
|
|
from src.models import model_called
|
|
from src.trainers import OptunaTrainer, Trainer, FullTrainer
|
|
|
|
|
|
def train(
|
|
device,
|
|
dataset: str,
|
|
data_root: str,
|
|
n_trials: int | None = None,
|
|
size: int | None = None,
|
|
context_length: int | None = None,
|
|
method: str = 'optuna',
|
|
model_name: str | None = None,
|
|
model_path: str | None = None,
|
|
model_out: str | None = None,
|
|
results_dir: str = 'results'
|
|
):
|
|
batch_size = 64
|
|
|
|
assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
|
|
|
|
if model_name:
|
|
print(f"Creating model: {model_name}")
|
|
model = model_called[model_name]
|
|
else:
|
|
print("Loading model from disk")
|
|
model = torch.load(model_path, weights_only=False)
|
|
|
|
dataset_common_args = {
|
|
'root': data_root,
|
|
'transform': lambda x: x.to(device),
|
|
}
|
|
|
|
if size:
|
|
dataset_common_args['size'] = size
|
|
|
|
if context_length:
|
|
dataset_common_args['context_length'] = context_length
|
|
|
|
print("Loading in the dataset...")
|
|
if dataset in dataset_called:
|
|
training_set = dataset_called[dataset](split='train', **dataset_common_args)
|
|
validate_set = dataset_called[dataset](split='validation', **dataset_common_args)
|
|
else:
|
|
# TODO Allow to import arbitrary files
|
|
raise NotImplementedError(f"Importing external datasets is not implemented yet")
|
|
|
|
if method == 'fetch':
|
|
# TODO More to earlier in chain, because now everything is converted into tensors as well?
|
|
exit(0)
|
|
|
|
print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
|
|
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
|
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
|
|
|
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
|
|
|
|
print("Training")
|
|
best_model = trainer.execute(
|
|
model=model,
|
|
context_length=context_length,
|
|
train_loader=training_loader,
|
|
validation_loader=validation_loader,
|
|
n_epochs=n_trials,
|
|
device=device
|
|
)
|
|
|
|
print("Saving model...")
|
|
f = model_out or f"saved_models/{model.__class__.__name__}.pt"
|
|
# Make sure path exists
|
|
Path(f).parent.mkdir(parents=True, exist_ok=True)
|
|
torch.save(best_model, f)
|
|
print(f"Saved model to '{f}'")
|
|
|