fix: fixed model shapes + redid training loop
parent ed44d5b283
commit eb4a014aa1
3 changed files with 68 additions and 56 deletions
@@ -1,45 +1,52 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
-from torch.nn.functional import softmax


-class CausalConv1d(nn.Conv1d):
-    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
-        super().__init__(input_channels, output_channels, kernel_size, padding=kernel_size-1, **kwargs)
-
-    def forward(self, input: Tensor) -> Tensor:
-        return super().forward(input)[:, :, :input.size(-1)]
-
-
 class CNNPredictor(nn.Module):
     def __init__(
         self,
         vocab_size=256,
-        num_layers=3,
-        hidden_dim=128,
-        kernel_size=3,
-        dropout_prob=0.1,
-        use_batchnorm=False
+        embed_dim=64,
+        hidden_channels=128,
     ):
         super().__init__()
-        self.embedding = nn.Embedding(vocab_size, hidden_dim)
-        layers = []
-        in_channels = hidden_dim
-        for _ in range(num_layers):
-            out_channels = hidden_dim
-            layers.append(CausalConv1d(in_channels, out_channels, kernel_size))
-            if use_batchnorm:
-                layers.append(nn.BatchNorm1d(out_channels))
-            layers.append(nn.ReLU())
-            layers.append(nn.Dropout(dropout_prob))
-            in_channels = out_channels
-
-        self.network = nn.Sequential(*layers)
-        self.output_layer = nn.Linear(hidden_dim, vocab_size)
+        # 1. Embedding: maps bytes (0–255) → vectors
+        self.embed = nn.Embedding(vocab_size, embed_dim)

+        # 2. Convolutional feature extractor
+        self.conv_layers = nn.Sequential(
+            nn.Conv1d(embed_dim, hidden_channels, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
+            nn.ReLU(),
+            nn.Conv1d(hidden_channels, hidden_channels, kernel_size=5, padding=2),
+            nn.ReLU(),
+        )
+
+        # 3. Global pooling to collapse sequence length
+        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_channels, 1)
+
+        # 4. Final classifier
+        self.fc = nn.Linear(hidden_channels, vocab_size)  # → (B, 256)
+
+    def forward(self, x):
+        """
+        x: LongTensor of shape (B, 128), values 0-255
+        """
+        # embed: (B, 128, embed_dim)
+        x = self.embed(x)
+
+        # conv1d expects (B, C_in, L) → swap dims
+        x = x.transpose(1, 2)  # (B, embed_dim, 128)
+
+        # apply CNN
+        x = self.conv_layers(x)  # (B, hidden_channels, 128)
+
+        # global average pooling over sequence
+        x = self.pool(x).squeeze(-1)  # (B, hidden_channels)
+
+        # final classifier
+        logits = self.fc(x)  # (B, 256)
+        return logits

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        emdedding = self.embedding(x)  # B, L, H
-        emdedding = emdedding.transpose(1, 2)  # B, H, L
-        prediction = self.network(emdedding)
-        last_prediction = prediction[:, :, -1]
-        return self.output_layer(last_prediction)
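A quick way to sanity-check the new shapes is a small smoke test like the sketch below; the import path is hypothetical, since the diff does not show the file name.

import torch
from model import CNNPredictor  # hypothetical import path; adjust to the actual module

# Byte-level input: batch of 4 windows, 128 bytes each.
model = CNNPredictor(vocab_size=256, embed_dim=64, hidden_channels=128)
x = torch.randint(0, 256, (4, 128))  # (B, 128) LongTensor of byte ids
logits = model(x)                    # (B, 256) next-byte logits
assert logits.shape == (4, 256)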
@@ -12,19 +12,13 @@ from train import train


 def create_model(trial: tr.Trial, vocab_size: int = 256):
-    num_layers = trial.suggest_int("num_layers", 1, 6)
     hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    kernel_size = trial.suggest_int("kernel_size", 2, 7)
-    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
-    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
+    embedding_dim = trial.suggest_int("embedding_dim", 64, 512, log=True)

     return CNNPredictor(
         vocab_size=vocab_size,
-        num_layers=num_layers,
         hidden_dim=hidden_dim,
-        kernel_size=kernel_size,
-        dropout_prob=dropout_prob,
-        use_batchnorm=use_batchnorm
+        embed_dim=embedding_dim,
     )

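One way to exercise create_model in isolation is Optuna's FixedTrial, which replays a fixed parameter set without running a study. This is only a sketch: the parameter values are arbitrary, and it assumes the keyword arguments passed to CNNPredictor line up with the model's new signature.

import optuna

# Replay a fixed set of hyperparameters through create_model
# (values are arbitrary examples, not tuned settings).
trial = optuna.trial.FixedTrial({"hidden_dim": 128, "embedding_dim": 128})
model = create_model(trial)
print(model)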
@@ -2,7 +2,6 @@ import torch
 import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

 from typing import Callable
-

@@ -13,22 +12,28 @@ def train(
     loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
     epochs: int = 100,
     learning_rate: float = 1e-3,
-    weight_decay: float = 1e-8
+    weight_decay: float = 1e-8,
+    device="cuda"
 ) -> tuple[list[float], list[float]]:

+    model.to(device)
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

     avg_training_losses = []
     avg_validation_losses = []

     for epoch in range(epochs):

         model.train()
         total_loss = []

-        for data in tqdm(training_loader):
+        for x, y in tqdm(training_loader):
+            x = x.long().to(device)  # important for Embedding
+            y = y.long().to(device)  # must be (B,) for CE

             optimizer.zero_grad()
-            x_hat = model(data)
+            logits = model(x)  # (B, 256)
+            loss = loss_fn(logits, y)

-            loss = loss_fn(x_hat, data)
             loss.backward()
             optimizer.step()

@@ -36,15 +41,21 @@ def train(

         avg_training_losses.append(sum(total_loss) / len(total_loss))

+        # ----- validation -----
+        model.eval()
         with torch.no_grad():
             losses = []
-            for data in validation_loader:
-                x_hat = model(data)
-                loss = loss_fn(x_hat, data)
+            for x, y in validation_loader:
+                x = x.long().to(device)
+                y = y.long().to(device)

+                logits = model(x)
+                loss = loss_fn(logits, y)
                 losses.append(loss.item())

         avg_loss = sum(losses) / len(losses)
         avg_validation_losses.append(avg_loss)
-        tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
+        tqdm.write(f"epoch: {epoch + 1}, avg val loss = {avg_loss:.4f}")

     return avg_training_losses, avg_validation_losses
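The reworked loop expects each batch to unpack into an (x, y) pair, with x as (B, 128) byte IDs for the embedding and y as (B,) class targets for cross-entropy. Below is a minimal sketch with random filler data; the loader construction and the keyword names for train's leading parameters are assumptions inferred from the loop body, not part of this diff.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Random filler data, purely illustrative: 1024 windows of 128 bytes each,
# with one next-byte target per window.
xs = torch.randint(0, 256, (1024, 128))
ys = torch.randint(0, 256, (1024,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=32, shuffle=True)

model = CNNPredictor()  # the class from the first file in this commit
train_losses, val_losses = train(
    model=model,
    training_loader=loader,
    validation_loader=loader,  # reused here only to keep the sketch short
    loss_fn=nn.CrossEntropyLoss(),
    epochs=2,
    device="cuda" if torch.cuda.is_available() else "cpu",
)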