feat: Results dir

This commit is contained in:
Tibo De Peuter 2025-12-11 15:38:43 +01:00
parent a4583d402b
commit 97e84a97db
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
4 changed files with 12 additions and 6 deletions

View file

@ -32,7 +32,8 @@ def main():
model_name=args.model,
model_path=args.model_load_path,
model_out=args.model_save_path,
context_length=args.context
context_length=args.context,
results_dir=args.results
)
case 'compress':

View file

@ -9,6 +9,7 @@ def parse_arguments():
help="Enable debug mode: smaller datasets, more information")
parser.add_argument("--verbose", "-v", action="store_true", required=False,
help="Enable verbose mode")
parser.add_argument("--results", type=str, required=True, help="path to save graphs to")
dataparser = ArgumentParser(add_help=False)
dataparser.add_argument("--data-root", type=str, required=False)

View file

@ -18,7 +18,8 @@ def train(
method: str = 'optuna',
model_name: str | None = None,
model_path: str | None = None,
model_out: str | None = None
model_out: str | None = None,
results_dir: str = 'results'
):
batch_size = 64
@ -58,7 +59,7 @@ def train(
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
print("Training")
best_model = trainer.execute(

View file

@ -6,8 +6,10 @@ from .trainer import Trainer
from ..models import Model
from ..utils import print_losses
class FullTrainer(Trainer):
def __init__(self, results_dir: str = 'results'):
self.results_dir = results_dir
def execute(
self,
model: Model,
@ -20,7 +22,8 @@ class FullTrainer(Trainer):
raise ValueError("Model must be provided: run optuna optimizations first")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
print_losses(train_loss, val_loss)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
device=device)
print_losses(train_loss, val_loss, filename=f"{self.results_dir}/{model.__class__.__name__}-losses.png")
return model