feat: Results dir
This commit is contained in:
parent
a4583d402b
commit
97e84a97db
4 changed files with 12 additions and 6 deletions
3
main.py
3
main.py
|
|
@ -32,7 +32,8 @@ def main():
|
||||||
model_name=args.model,
|
model_name=args.model,
|
||||||
model_path=args.model_load_path,
|
model_path=args.model_load_path,
|
||||||
model_out=args.model_save_path,
|
model_out=args.model_save_path,
|
||||||
context_length=args.context
|
context_length=args.context,
|
||||||
|
results_dir=args.results
|
||||||
)
|
)
|
||||||
|
|
||||||
case 'compress':
|
case 'compress':
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,7 @@ def parse_arguments():
|
||||||
help="Enable debug mode: smaller datasets, more information")
|
help="Enable debug mode: smaller datasets, more information")
|
||||||
parser.add_argument("--verbose", "-v", action="store_true", required=False,
|
parser.add_argument("--verbose", "-v", action="store_true", required=False,
|
||||||
help="Enable verbose mode")
|
help="Enable verbose mode")
|
||||||
|
parser.add_argument("--results", type=str, required=True, help="path to save graphs to")
|
||||||
|
|
||||||
dataparser = ArgumentParser(add_help=False)
|
dataparser = ArgumentParser(add_help=False)
|
||||||
dataparser.add_argument("--data-root", type=str, required=False)
|
dataparser.add_argument("--data-root", type=str, required=False)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,8 @@ def train(
|
||||||
method: str = 'optuna',
|
method: str = 'optuna',
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
model_path: str | None = None,
|
model_path: str | None = None,
|
||||||
model_out: str | None = None
|
model_out: str | None = None,
|
||||||
|
results_dir: str = 'results'
|
||||||
):
|
):
|
||||||
batch_size = 64
|
batch_size = 64
|
||||||
|
|
||||||
|
|
@ -58,7 +59,7 @@ def train(
|
||||||
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
|
||||||
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)
|
||||||
|
|
||||||
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()
|
trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer(results_dir=results_dir)
|
||||||
|
|
||||||
print("Training")
|
print("Training")
|
||||||
best_model = trainer.execute(
|
best_model = trainer.execute(
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,10 @@ from .trainer import Trainer
|
||||||
from ..models import Model
|
from ..models import Model
|
||||||
from ..utils import print_losses
|
from ..utils import print_losses
|
||||||
|
|
||||||
|
|
||||||
class FullTrainer(Trainer):
|
class FullTrainer(Trainer):
|
||||||
|
def __init__(self, results_dir: str = 'results'):
|
||||||
|
self.results_dir = results_dir
|
||||||
|
|
||||||
def execute(
|
def execute(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: Model,
|
||||||
|
|
@ -20,7 +22,8 @@ class FullTrainer(Trainer):
|
||||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
raise ValueError("Model must be provided: run optuna optimizations first")
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
|
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
|
||||||
print_losses(train_loss, val_loss)
|
device=device)
|
||||||
|
print_losses(train_loss, val_loss, filename=f"{self.results_dir}/{model.__class__.__name__}-losses.png")
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
||||||
Reference in a new issue