Merge pull request #3 from ML/model-rewrite

fix: fixed model shapes + re-edited training loop
This commit is contained in:
Robin Meersman 2025-11-27 19:27:31 +01:00 committed by GitHub Enterprise
commit 2ab4abdf93
3 changed files with 68 additions and 56 deletions

View file

@@ -12,19 +12,13 @@ from .train import train
def create_model(trial: tr.Trial, vocab_size: int = 256):
num_layers = trial.suggest_int("num_layers", 1, 6)
hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
kernel_size = trial.suggest_int("kernel_size", 2, 7)
dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
embedding_dim = trial.suggest_int("embedding_dim", 64, 512, log=True)
return CNNPredictor(
vocab_size=vocab_size,
num_layers=num_layers,
hidden_dim=hidden_dim,
kernel_size=kernel_size,
dropout_prob=dropout_prob,
use_batchnorm=use_batchnorm
embed_dim=embedding_dim,
)

View file

@@ -2,7 +2,6 @@ import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable
@@ -13,22 +12,28 @@ def train(
loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
epochs: int = 100,
learning_rate: float = 1e-3,
weight_decay: float = 1e-8
weight_decay: float = 1e-8,
device="cuda"
) -> tuple[list[float], list[float]]:
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
avg_training_losses = []
avg_validation_losses = []
for epoch in range(epochs):
model.train()
total_loss = []
for data in tqdm(training_loader):
for x, y in tqdm(training_loader):
x = x.long().to(device) # important for Embedding
y = y.long().to(device) # must be (B,) for CE
optimizer.zero_grad()
x_hat = model(data)
loss = loss_fn(x_hat, data)
logits = model(x) # (B, 256)
loss = loss_fn(logits, y)
loss.backward()
optimizer.step()
@@ -36,15 +41,21 @@ def train(
avg_training_losses.append(sum(total_loss) / len(total_loss))
# ----- validation -----
model.eval()
with torch.no_grad():
losses = []
for data in validation_loader:
x_hat = model(data)
loss = loss_fn(x_hat, data)
for x, y in validation_loader:
x = x.long().to(device)
y = y.long().to(device)
logits = model(x)
loss = loss_fn(logits, y)
losses.append(loss.item())
avg_loss = sum(losses) / len(losses)
avg_validation_losses.append(avg_loss)
tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
tqdm.write(f"epoch: {epoch + 1}, avg val loss = {avg_loss:.4f}")
return avg_training_losses, avg_validation_losses