Merge pull request #13 from ML/autoencoder

Autoencoder
Robin Meersman 2025-12-13 18:04:09 +01:00 committed by GitHub Enterprise
commit 0ab495165f
9 changed files with 130 additions and 37 deletions

View file

@@ -1,8 +1,10 @@
 from .Model import Model
+from .autoencoder import AutoEncoder
 from .cnn import CNNPredictor
 from .transformer import ByteTransformer
 
 model_called: dict[str, type[Model]] = {
     'cnn': CNNPredictor,
-    'transformer': ByteTransformer
+    'transformer': ByteTransformer,
+    'autoencoder': AutoEncoder
 }
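
The hunk above registers the new model under the 'autoencoder' key, so it is constructed the same way as the existing entries. A minimal usage sketch, assuming the registry module is src/models/__init__.py (the constructor values are illustrative, not taken from this PR):

    from src.models import model_called

    model_cls = model_called["autoencoder"]  # resolves to the AutoEncoder class
    model = model_cls(input_size=128, channel_count=4, latent_dim=32)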

View file

@@ -0,0 +1 @@
+from .autoencoder import AutoEncoder

View file

@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+
+from src.models import Model
+
+
+class Encoder(nn.Module):
+    def __init__(self, data_length, channel_count, latent_dim):
+        super(Encoder, self).__init__()
+        self._encoder = nn.Sequential(*[
+            nn.Conv1d(1, channel_count, kernel_size=3, padding=1),  # (channel_count, L)
+            nn.BatchNorm1d(channel_count),
+            nn.ReLU(),
+            nn.Conv1d(channel_count, 2 * channel_count, stride=2, kernel_size=3, padding=1),  # (2 * channel_count, L / 2)
+            nn.BatchNorm1d(2 * channel_count),
+            nn.Flatten(),  # 2 * channel_count * L / 2
+            nn.Linear(2 * channel_count * data_length // 2, latent_dim),
+            nn.ReLU()
+        ])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self._encoder(x)
+
+
+class Decoder(nn.Module):
+    def __init__(self, latent_dim, channel_count, data_length):
+        super(Decoder, self).__init__()
+        self._decoder = nn.Sequential(*[
+            nn.Linear(latent_dim, 2 * channel_count * data_length // 2),
+            nn.ReLU(),
+            nn.Unflatten(1, (2 * channel_count, data_length // 2)),
+            nn.BatchNorm1d(2 * channel_count),
+            nn.ConvTranspose1d(2 * channel_count, channel_count, kernel_size=3, stride=2, padding=1, output_padding=1),
+            nn.BatchNorm1d(channel_count),
+            nn.ReLU(),
+            nn.ConvTranspose1d(channel_count, 1, kernel_size=3, padding=1),
+        ])
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self._decoder(x)
+
+
+class AutoEncoder(Model):
+    def __init__(self, input_size, channel_count, latent_dim):
+        super().__init__(loss_function=nn.MSELoss())
+        self.encoder = Encoder(input_size, channel_count, latent_dim)
+        self.decoder = Decoder(latent_dim, channel_count, input_size)
+
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.encoder(x)
+
+    def decode(self, x: torch.Tensor) -> torch.Tensor:
+        return self.decoder(x)
+
+    def forward(self, x: torch.LongTensor) -> torch.Tensor:
+        x = x.float() / 255.0  # scale byte values to [0, 1]
+        x = x.unsqueeze(1)  # add channel dimension -> (B, 1, L)
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+        return decoded
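
The shapes line up because the single stride-2 convolution halves the length once: the flattened size is 2 * channel_count * (data_length / 2) = channel_count * data_length, and the decoder mirrors that path back to (B, 1, L). A quick round-trip smoke test, as a sketch (constructor values illustrative; assumes AutoEncoder is re-exported from src.models):

    import torch
    from src.models import AutoEncoder

    model = AutoEncoder(input_size=128, channel_count=4, latent_dim=32)
    model.eval()  # use BatchNorm running stats; avoids train-mode batch-size pitfalls
    x = torch.randint(0, 256, (8, 128))  # batch of 8 byte sequences
    out = model(x)  # encode: (8, 1, 128) -> (8, 32); decode: (8, 32) -> (8, 1, 128)
    assert out.shape == (8, 1, 128)

Note the final ConvTranspose1d has no activation, so reconstructions are unbounded rather than clamped to the [0, 1] range of the targets.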

View file

@@ -18,10 +18,13 @@ class CNNPredictor(Model):
         # 2. Convolutional feature extractor
         self.conv_layers = nn.Sequential(
             nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
             nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
             nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
         )
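
The insertions above repeat the same Conv1d -> BatchNorm1d -> ReLU unit three times. A hypothetical helper (not part of this PR) could express the block once:

    import torch.nn as nn

    def conv_block(in_ch: int, out_ch: int) -> nn.Sequential:
        # hypothetical refactor of the repeated conv -> batch-norm -> activation unit
        return nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
        )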

View file

@@ -26,7 +26,7 @@ def train(
     assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"
 
     if model_name:
-        print("Creating model")
+        print(f"Creating model: {model_name}")
         model = model_called[model_name]
     else:
         print("Loading model from disk")
@@ -64,6 +64,7 @@
     print("Training")
     best_model = trainer.execute(
         model=model,
+        context_length=context_length,
         train_loader=training_loader,
         validation_loader=validation_loader,
         n_epochs=n_trials,

View file

@@ -13,6 +13,7 @@ class FullTrainer(Trainer):
     def execute(
         self,
         model: Model,
+        context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int | None,

View file

@@ -5,22 +5,21 @@ from torch.utils.data import DataLoader
 
 from .train import train
 from .trainer import Trainer
-from ..models import Model, CNNPredictor, ByteTransformer
+from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder
 
 
-def create_model(trial: tr.Trial, model: nn.Module):
-    match model.__class__:
-        case CNNPredictor.__class__:
-            return model(
-                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
-                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
-                vocab_size=256,
-            )
-        case ByteTransformer.__class__:
-            nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
-            # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
-            return model(
-                d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
-                nhead=nhead,
-                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
-                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
+def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
+    if model_cls is CNNPredictor:
+        return CNNPredictor(
+            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
+            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+            vocab_size=256,
+        )
+    if model_cls is ByteTransformer:
+        nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
+        # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
+        return ByteTransformer(
+            d_model=context_length,
+            nhead=nhead,
+            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
+            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
@@ -29,6 +28,14 @@ def create_model(trial: tr.Trial, model: nn.Module):
-                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
-                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
-            )
+            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
+            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
+        )
+    if model_cls is AutoEncoder:
+        channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
+        latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
+        return AutoEncoder(
+            channel_count=channel_count,
+            latent_dim=latent_dim,
+            input_size=context_length,
+        )
 
     return None
@@ -36,10 +43,11 @@ def objective_function(
     trial: tr.Trial,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    model: Model,
+    model: type[Model],
+    context_length: int,
     device: str
 ):
-    model = create_model(trial, model).to(device)
+    model = create_model(trial, model, context_length).to(device)
     _, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
     return min(validation_loss)
@@ -52,7 +60,8 @@ class OptunaTrainer(Trainer):
     def execute(
         self,
-        model: Model,
+        model: type[Model],
+        context_length,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int,
@@ -60,13 +69,19 @@
     ) -> nn.Module:
         study = optuna.create_study(direction="minimize")
         study.optimize(
-            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
+            lambda trial: objective_function(trial, train_loader, validation_loader, model, context_length, device),
             n_trials=self.n_trials
         )
 
         best_params = study.best_trial.params
-        best_model = model(
-            **best_params
-        )
+        if model is AutoEncoder:
+            best_model = AutoEncoder(
+                input_size=context_length,
+                **best_params
+            )
+        elif model is CNNPredictor:
+            best_model = CNNPredictor(**best_params)
+        else:
+            raise ValueError(f"Unknown model type: {model}")
         return best_model
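
After this hunk the Optuna path receives a model class and instantiates it once per trial. A hedged usage sketch; the OptunaTrainer constructor and any parameters beyond those visible in this diff are assumptions:

    trainer = OptunaTrainer(n_trials=20)  # assumed constructor; only self.n_trials is visible here
    best_model = trainer.execute(
        model=AutoEncoder,  # the class itself, not an instance
        context_length=128,
        train_loader=training_loader,
        validation_loader=validation_loader,
        n_epochs=10,
    )

Note the best-model rebuild handles AutoEncoder and CNNPredictor but raises ValueError for ByteTransformer, whose d_model=context_length argument is likewise absent from best_params and would need the same special-casing.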

View file

@@ -4,7 +4,7 @@ import torch
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
-from ..models import ByteTransformer, Model
+from ..models import ByteTransformer, Model, AutoEncoder
 
 
 def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
@@ -53,9 +53,12 @@ def train(
             y = y.long().to(device)
 
             optimizer.zero_grad()
-            logits = _forward(model, x, device)
-            loss = loss_fn(logits, y)
+            pred = _forward(model, x, device)
+            if isinstance(model, AutoEncoder):
+                loss = loss_fn(pred, (x.float() / 255.0).unsqueeze(1))  # match the (B, 1, L) reconstruction
+            else:
+                loss = loss_fn(pred, y)
 
             loss.backward()
             optimizer.step()
@@ -71,8 +74,12 @@ def train(
             x = x.long().to(device)
             y = y.long().to(device)
 
-            logits = _forward(model, x, device)
-            loss = loss_fn(logits, y)
+            pred = _forward(model, x, device)
+            if isinstance(model, AutoEncoder):
+                loss = loss_fn(pred, (x.float() / 255.0).unsqueeze(1))  # match the (B, 1, L) reconstruction
+            else:
+                loss = loss_fn(pred, y)
             losses.append(loss.item())
 
         avg_loss = sum(losses) / len(losses)
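
Both the training and validation branches now use the same target rule: for the autoencoder, the loss compares the reconstruction against the scaled input instead of the label tensor y. One step in isolation, as a sketch (the batch is a stand-in for a loader batch; model is an AutoEncoder instance):

    import torch

    x = torch.randint(0, 256, (8, 128)).long()
    pred = model(x)  # (8, 1, 128) reconstruction
    target = (x.float() / 255.0).unsqueeze(1)  # same scaling as AutoEncoder.forward, shaped like pred
    loss = model.loss_function(pred, target)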

View file

@@ -10,7 +10,8 @@ class Trainer(ABC):
     @abstractmethod
     def execute(
         self,
-        model: nn.Module | None,
+        model: nn.Module | type[nn.Module] | None,
+        context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int | None,
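
Every Trainer implementation must now accept context_length, matching the FullTrainer and OptunaTrainer updates above. A hypothetical stub showing the updated contract (parameter list only as far as this diff reveals it):

    class DryRunTrainer(Trainer):
        # hypothetical subclass: reports what would run, trains nothing
        def execute(self, model, context_length, train_loader, validation_loader, n_epochs):
            print(f"would train {model} with context_length={context_length} for {n_epochs} epochs")
            return model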