fix: fixed model shapes + redid training loop

Robin Meersman 2025-11-27 14:11:53 +01:00
parent ed44d5b283
commit eb4a014aa1
3 changed files with 68 additions and 56 deletions

View file

@@ -1,45 +1,52 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
-from torch.nn.functional import softmax
-
-
-class CausalConv1d(nn.Conv1d):
-    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
-        super().__init__(input_channels, output_channels, kernel_size, padding=kernel_size-1, **kwargs)
-
-    def forward(self, input: Tensor) -> Tensor:
-        return super().forward(input)[:, :, :input.size(-1)]


 class CNNPredictor(nn.Module):
     def __init__(
         self,
         vocab_size=256,
-        num_layers=3,
-        hidden_dim=128,
-        kernel_size=3,
-        dropout_prob=0.1,
-        use_batchnorm=False
+        embed_dim=64,
+        hidden_channels=128,
     ):
         super().__init__()
-        self.embedding = nn.Embedding(vocab_size, hidden_dim)
-        layers = []
-        in_channels = hidden_dim
-        for _ in range(num_layers):
-            out_channels = hidden_dim
-            layers.append(CausalConv1d(in_channels, out_channels, kernel_size))
-            if use_batchnorm:
-                layers.append(nn.BatchNorm1d(out_channels))
-            layers.append(nn.ReLU())
-            layers.append(nn.Dropout(dropout_prob))
-            in_channels = out_channels
-        self.network = nn.Sequential(*layers)
-        self.output_layer = nn.Linear(hidden_dim, vocab_size)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        emdedding = self.embedding(x)  # B, L, H
-        emdedding = emdedding.transpose(1, 2)  # B, H, L
-        prediction = self.network(emdedding)
-        last_prediction = prediction[:, :, -1]
-        return self.output_layer(last_prediction)
+        # 1. Embedding: maps bytes (0–255) → vectors
+        self.embed = nn.Embedding(vocab_size, embed_dim)
+
+        # 2. Convolutional feature extractor
+        self.conv_layers = nn.Sequential(
+            nn.Conv1d(embed_dim, hidden_channels, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
+            nn.ReLU(),
+        )
+
+        # 3. Global pooling to collapse sequence length
+        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_channels, 1)
+
+        # 4. Final classifier
+        self.fc = nn.Linear(hidden_channels, vocab_size)  # → (B, 256)
+
+    def forward(self, x):
+        """
+        x: LongTensor of shape (B, 128), values 0-255
+        """
+        # embed: (B, 128, embed_dim)
+        x = self.embed(x)
+        # conv1d expects (B, C_in, L) → swap dims
+        x = x.transpose(1, 2)  # (B, embed_dim, 128)
+        # apply CNN
+        x = self.conv_layers(x)  # (B, hidden_channels, 128)
+        # global average pooling over sequence
+        x = self.pool(x).squeeze(-1)  # (B, hidden_channels)
+        # final classifier
+        logits = self.fc(x)  # (B, 256)
+        return logits
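
For reference, a minimal shape check for the rewritten model (a sketch, not part of the commit; batch size 4 and the fixed sequence length of 128 are illustrative):

    import torch

    # assumes CNNPredictor from the file above is importable
    model = CNNPredictor(vocab_size=256, embed_dim=64, hidden_channels=128)
    x = torch.randint(0, 256, (4, 128))  # (B, L) byte tokens in [0, 255]
    logits = model(x)                    # (B, 256) unnormalized next-byte scores
    assert logits.shape == (4, 256)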

View file

@@ -12,19 +12,13 @@ from train import train
 def create_model(trial: tr.Trial, vocab_size: int = 256):
-    num_layers = trial.suggest_int("num_layers", 1, 6)
     hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    kernel_size = trial.suggest_int("kernel_size", 2, 7)
-    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
-    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
+    embedding_dim = trial.suggest_int("embedding_dim", 64, 512, log=True)

     return CNNPredictor(
         vocab_size=vocab_size,
-        num_layers=num_layers,
-        hidden_dim=hidden_dim,
-        kernel_size=kernel_size,
-        dropout_prob=dropout_prob,
-        use_batchnorm=use_batchnorm
+        hidden_channels=hidden_dim,  # new CNNPredictor takes hidden_channels, not hidden_dim
+        embed_dim=embedding_dim,
     )
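
A hypothetical sketch of how create_model would be wired into an Optuna study; the dummy data, loader names, trial count, and keyword-only call to train are assumptions, not part of this commit:

    import optuna
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # dummy byte data so the sketch is self-contained; real loaders replace this
    data = TensorDataset(torch.randint(0, 256, (256, 128)), torch.randint(0, 256, (256,)))
    train_loader = DataLoader(data, batch_size=32, shuffle=True)
    val_loader = DataLoader(data, batch_size=32)

    def objective(trial):
        model = create_model(trial)  # samples hidden_dim / embedding_dim from the trial
        _, val_losses = train(
            model=model,
            training_loader=train_loader,
            validation_loader=val_loader,
            loss_fn=torch.nn.CrossEntropyLoss(),
            epochs=5,
        )
        return val_losses[-1]  # minimize the final validation loss

    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)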

View file

@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 from typing import Callable
@@ -13,22 +12,28 @@ def train(
     loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
     epochs: int = 100,
     learning_rate: float = 1e-3,
-    weight_decay: float = 1e-8
+    weight_decay: float = 1e-8,
+    device="cuda"
 ) -> tuple[list[float], list[float]]:
+    model.to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

     avg_training_losses = []
     avg_validation_losses = []

     for epoch in range(epochs):
         model.train()
         total_loss = []
-        for data in tqdm(training_loader):
+        for x, y in tqdm(training_loader):
+            x = x.long().to(device)  # important for Embedding
+            y = y.long().to(device)  # must be (B,) for CE
             optimizer.zero_grad()
-
-            x_hat = model(data)
-            loss = loss_fn(x_hat, data)
+            logits = model(x)  # (B, 256)
+            loss = loss_fn(logits, y)

             loss.backward()
             optimizer.step()
@@ -36,15 +41,21 @@ def train(
         avg_training_losses.append(sum(total_loss) / len(total_loss))

+        # ----- validation -----
+        model.eval()
         with torch.no_grad():
             losses = []
-            for data in validation_loader:
-                x_hat = model(data)
-                loss = loss_fn(x_hat, data)
+            for x, y in validation_loader:
+                x = x.long().to(device)
+                y = y.long().to(device)
+                logits = model(x)
+                loss = loss_fn(logits, y)
                 losses.append(loss.item())

             avg_loss = sum(losses) / len(losses)
             avg_validation_losses.append(avg_loss)
-            tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
+            tqdm.write(f"epoch: {epoch + 1}, avg val loss = {avg_loss:.4f}")

     return avg_training_losses, avg_validation_losses
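
The reworked loop pairs (B, 256) logits with (B,) integer next-byte targets, which is the contract nn.CrossEntropyLoss expects. A minimal self-contained check of that contract (random tensors only, not from the commit):

    import torch
    import torch.nn as nn

    # Shape contract the loop relies on: logits (B, 256) vs. targets (B,)
    logits = torch.randn(32, 256)            # stand-in for model(x)
    targets = torch.randint(0, 256, (32,))   # next-byte class indices
    loss = nn.CrossEntropyLoss()(logits, targets)

    # Greedy next-byte prediction from the same logits
    next_byte = logits.argmax(dim=-1)        # (B,) values in [0, 255]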