changes to training + added autoencoder

RobinMeersman 2025-12-13 17:53:01 +01:00
parent 6e591bb470
commit a4a41d190b
7 changed files with 91 additions and 55 deletions

View file

@@ -5,15 +5,16 @@ from src.models import Model
 class Encoder(nn.Module):
-    def __init__(self, input_size, hidden_size, latent_dim):
+    def __init__(self, data_length, channel_count, latent_dim):
         super(Encoder, self).__init__()
         self._encoder = nn.Sequential(*[
-            nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1),
-            nn.BatchNorm1d(hidden_size),
+            nn.Conv1d(1, channel_count, kernel_size=3, padding=1),  # (channel_count, L)
+            nn.BatchNorm1d(channel_count),
             nn.ReLU(),
-            nn.Conv1d(hidden_size, 2 * hidden_size, stride=2, kernel_size=3, padding=1),
-            nn.BatchNorm1d(2 * hidden_size),
-            nn.Linear(2 * hidden_size, latent_dim),
+            nn.Conv1d(channel_count, 2 * channel_count, stride=2, kernel_size=3, padding=1),  # (2 * channel_count, L / 2)
+            nn.BatchNorm1d(2 * channel_count),
+            nn.Flatten(),  # 2 * channel_count * L / 2
+            nn.Linear(2 * channel_count * data_length // 2, latent_dim),
             nn.ReLU()
         ])
@@ -22,27 +23,28 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
-    def __init__(self, input_size, hidden_size, output_size):
+    def __init__(self, latent_dim, channel_count, data_length):
         super(Decoder, self).__init__()
-        super._decoder = nn.Sequential(*[
-            nn.Linear(input_size, 2 * hidden_size),
+        self._decoder = nn.Sequential(*[
+            nn.Linear(latent_dim, 2 * channel_count * data_length // 2),
             nn.ReLU(),
-            nn.BatchNorm1d(2 * hidden_size),
-            nn.ConvTranspose1d(2 * hidden_size, hidden_size, kernel_size=3, stride=2, padding=1, output_padding=1),
-            nn.BatchNorm1d(hidden_size),
+            nn.Unflatten(1, (2 * channel_count, data_length // 2)),
+            nn.BatchNorm1d(2 * channel_count),
+            nn.ConvTranspose1d(2 * channel_count, channel_count, kernel_size=3, stride=2, padding=1, output_padding=1),
+            nn.BatchNorm1d(channel_count),
             nn.ReLU(),
-            nn.ConvTranspose1d(hidden_size, output_size, kernel_size=3, padding=1),
+            nn.ConvTranspose1d(channel_count, 1, kernel_size=3, padding=1),
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._decoder(x)


 class AutoEncoder(Model):
-    def __init__(self, input_size, hidden_size, latent_dim):
-        super().__init__(loss_function = nn.CrossEntropyLoss())
-        self.encoder = Encoder(input_size, hidden_size, latent_dim)
-        self.decoder = Decoder(latent_dim, hidden_size, input_size)
+    def __init__(self, input_size, channel_count, latent_dim):
+        super().__init__(loss_function = nn.MSELoss())
+        self.encoder = Encoder(input_size, channel_count, latent_dim)
+        self.decoder = Decoder(latent_dim, channel_count, input_size)

     def encode(self, x: torch.Tensor) -> torch.Tensor:
         return self.encoder(x)
@@ -50,5 +52,11 @@ class AutoEncoder(Model):
     def decode(self, x: torch.Tensor) -> torch.Tensor:
         return self.decoder(x)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.decode(self.encode(x))
+    def forward(self, x: torch.LongTensor) -> torch.Tensor:
+        x = x.float() / 255.0  # convert to floats
+        x = x.unsqueeze(1)  # add channel dimension --> (B, 1, L)
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+        return decoded
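
For a quick sanity check of the new Flatten/Unflatten wiring, a minimal sketch (the batch size and concrete sizes are made up, and the import assumes the models package re-exports AutoEncoder the same way the trainer imports it; data_length must be even for the stride-2 pair to round-trip):

import torch
from src.models import AutoEncoder  # assumed re-export, mirroring the trainer's imports

model = AutoEncoder(input_size=128, channel_count=4, latent_dim=32)
model.eval()  # BatchNorm1d falls back to running stats, so a tiny batch is fine

x = torch.randint(0, 256, (2, 128))  # (B, L) raw byte values, as the data loaders provide
with torch.no_grad():
    recon = model(x)
print(recon.shape)  # torch.Size([2, 1, 128]) -- forward adds the channel dimension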

View file

@@ -18,10 +18,13 @@ class CNNPredictor(Model):
         # 2. Convolutional feature extractor
         self.conv_layers = nn.Sequential(
             nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
             nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
             nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
         )
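
A standalone check of the Conv1d/BatchNorm1d/ReLU ordering added here: kernel_size=5 with padding=2 keeps the sequence length, and BatchNorm1d normalises each channel over batch and length (embed_dim, hidden_dim and the shapes below are illustrative values, not the Optuna-suggested ones):

import torch
import torch.nn as nn

embed_dim, hidden_dim = 32, 64  # illustrative only
block = nn.Sequential(
    nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
    nn.BatchNorm1d(hidden_dim),  # inserted between conv and activation, as in the diff
    nn.ReLU(),
)
x = torch.randn(4, embed_dim, 128)  # (B, embed_dim, L)
print(block(x).shape)  # torch.Size([4, 64, 128])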

View file

@@ -26,7 +26,7 @@ def train(
     assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
     if model_name:
-        print("Creating model")
+        print(f"Creating model: {model_name}")
         model = model_called[model_name]
     else:
         print("Loading model from disk")
@@ -64,6 +64,7 @@ def train(
     print("Training")
     best_model = trainer.execute(
         model=model,
+        context_length=context_length,
         train_loader=training_loader,
         validation_loader=validation_loader,
         n_epochs=n_trials,

View file

@@ -13,6 +13,7 @@ class FullTrainer(Trainer):
     def execute(
         self,
         model: Model,
+        context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int | None,

View file

@@ -5,30 +5,37 @@ from torch.utils.data import DataLoader
 from .train import train
 from .trainer import Trainer
-from ..models import Model, CNNPredictor, ByteTransformer
+from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder


-def create_model(trial: tr.Trial, model: nn.Module):
-    match model.__class__:
-        case CNNPredictor.__class__:
-            return model(
-                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
-                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
-                vocab_size=256,
-            )
-        case ByteTransformer.__class__:
-            nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
-            # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
-            return model(
-                d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
-                nhead=nhead,
-                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
-                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
-                dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
-                dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
-                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
-                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
-            )
+def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
+    if model_cls is CNNPredictor:
+        return CNNPredictor(
+            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
+            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+            vocab_size=256,
+        )
+    if model_cls is ByteTransformer:
+        nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
+        # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
+        return ByteTransformer(
+            d_model=context_length,
+            nhead=nhead,
+            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
+            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
+            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
+            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
+            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
+            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
+        )
+    if model_cls is AutoEncoder:
+        channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
+        latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
+        return AutoEncoder(
+            channel_count=channel_count,
+            latent_dim=latent_dim,
+            input_size=context_length,
+        )
     return None
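
create_model can be exercised outside a study with optuna's FixedTrial; a minimal sketch (the parameter values and the src.training.optuna_trainer import path are assumptions):

from optuna.trial import FixedTrial
from src.models import AutoEncoder
from src.training.optuna_trainer import create_model  # assumed module path

trial = FixedTrial({"channel_count": 4, "latent_dim": 48})
model = create_model(trial, AutoEncoder, context_length=128)
print(type(model).__name__)  # AutoEncoder, built with input_size=128 taken from context_length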
@@ -36,10 +43,11 @@ def objective_function(
     trial: tr.Trial,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    model: Model,
+    model: type[Model],
+    context_length: int,
     device: str
 ):
-    model = create_model(trial, model).to(device)
+    model = create_model(trial, model, context_length).to(device)
     _, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
     return min(validation_loss)
@@ -52,7 +60,8 @@ class OptunaTrainer(Trainer):
     def execute(
         self,
-        model: Model,
+        model: type[Model],
+        context_length,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int,
@@ -60,13 +69,19 @@ class OptunaTrainer(Trainer):
     ) -> nn.Module:
         study = optuna.create_study(direction="minimize")
         study.optimize(
-            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
+            lambda trial: objective_function(trial, train_loader, validation_loader, model, context_length, device),
             n_trials=self.n_trials
         )
         best_params = study.best_trial.params
-        best_model = model(
-            **best_params
-        )
+        if model is AutoEncoder:
+            best_model = AutoEncoder(
+                input_size=context_length,
+                **best_params
+            )
+        elif model is CNNPredictor:
+            best_model = CNNPredictor(**best_params)
+        else:
+            raise ValueError(f"Unknown model type: {model}")
         return best_model
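
study.best_trial.params only holds the Optuna-suggested values, which is why the rebuild has to pass input_size separately for the AutoEncoder (and why ByteTransformer, whose d_model is not a suggested parameter, currently falls through to the ValueError branch). Roughly what the rebuild sees, with illustrative values:

from src.models import AutoEncoder

context_length = 128                                   # illustrative; comes from execute() in practice
best_params = {"channel_count": 4, "latent_dim": 48}   # e.g. study.best_trial.params after AutoEncoder trials
best_model = AutoEncoder(input_size=context_length, **best_params)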

View file

@@ -4,7 +4,7 @@ import torch
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

-from ..models import ByteTransformer, Model
+from ..models import ByteTransformer, Model, AutoEncoder


 def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
@@ -53,9 +53,12 @@ def train(
             y = y.long().to(device)
             optimizer.zero_grad()
-            logits = _forward(model, x, device)
-            loss = loss_fn(logits, y)
+            pred = _forward(model, x, device)
+            if isinstance(model, AutoEncoder):
+                loss = loss_fn(pred, x.float() / 255.0)
+            else:
+                loss = loss_fn(pred, y)
             loss.backward()
             optimizer.step()
@@ -71,8 +74,12 @@ def train(
             x = x.long().to(device)
             y = y.long().to(device)
-            logits = _forward(model, x, device)
-            loss = loss_fn(logits, y)
+            pred = _forward(model, x, device)
+            if isinstance(model, AutoEncoder):
+                loss = loss_fn(pred, x.float() / 255.0)
+            else:
+                loss = loss_fn(pred, y)
             losses.append(loss.item())

     avg_loss = sum(losses) / len(losses)
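
Condensed, both loops now compute the loss the same way; a sketch with toy tensors (shapes are assumptions, and the squeeze(1) is added here only to make the reconstruction target and output shapes match element-wise, it is not part of the change above):

import torch
import torch.nn as nn

x = torch.randint(0, 256, (4, 128))       # raw byte window (B, L)
recon = torch.rand(4, 1, 128)             # what the AutoEncoder's forward returns
mse = nn.MSELoss()(recon.squeeze(1), x.float() / 255.0)  # reconstruction loss

logits = torch.randn(4, 256)              # what a next-byte predictor returns (B, vocab)
y = torch.randint(0, 256, (4,))
ce = nn.CrossEntropyLoss()(logits, y)     # prediction loss, as before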

View file

@@ -10,7 +10,8 @@ class Trainer(ABC):
     @abstractmethod
     def execute(
         self,
-        model: nn.Module | None,
+        model: nn.Module | type[nn.Module] | None,
+        context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int | None,