feat: Results dir
This commit is contained in:
parent
a4583d402b
commit
97e84a97db
4 changed files with 12 additions and 6 deletions
3
main.py
3
main.py
|
|
@ -32,7 +32,8 @@ def main():
|
|||
model_name=args.model,
|
||||
model_path=args.model_load_path,
|
||||
model_out=args.model_save_path,
|
||||
context_length=args.context
|
||||
context_length=args.context,
|
||||
results_dir=args.results
|
||||
)
|
||||
|
||||
case 'compress':
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ def parse_arguments():
|
|||
help="Enable debug mode: smaller datasets, more information")
|
||||
parser.add_argument("--verbose", "-v", action="store_true", required=False,
|
||||
help="Enable verbose mode")
|
||||
parser.add_argument("--results", type=str, required=True, help="path to save graphs to")
|
||||
|
||||
dataparser = ArgumentParser(add_help=False)
|
||||
dataparser.add_argument("--data-root", type=str, required=False)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,8 @@ def train(
|
|||
method: str = 'optuna',
|
||||
model_name: str | None = None,
|
||||
model_path: str | None = None,
|
||||
model_out: str | None = None
|
||||
model_out: str | None = None,
|
||||
results_dir: str = 'results'
|
||||
):
|
||||
batch_size = 64
|
||||
|
||||
|
|
@ -58,7 +59,7 @@ def train(
|
|||
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
||||
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
||||
|
||||
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
|
||||
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
|
||||
|
||||
print("Training")
|
||||
best_model = trainer.execute(
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@ from .trainer import Trainer
|
|||
from ..models import Model
|
||||
from ..utils import print_losses
|
||||
|
||||
|
||||
class FullTrainer(Trainer):
|
||||
def __init__(self, results_dir: str = 'results'):
|
||||
self.results_dir = results_dir
|
||||
|
||||
def execute(
|
||||
self,
|
||||
model: Model,
|
||||
|
|
@ -20,7 +22,8 @@ class FullTrainer(Trainer):
|
|||
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||
|
||||
model.to(device)
|
||||
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
|
||||
print_losses(train_loss, val_loss)
|
||||
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
|
||||
device=device)
|
||||
print_losses(train_loss, val_loss, filename=f"{self.results_dir}/{model.__class__.__name__}-losses.png")
|
||||
|
||||
return model
|
||||
|
|
|
|||
Reference in a new issue