changes to training + added autoencoder
parent 6e591bb470
commit a4a41d190b
7 changed files with 91 additions and 55 deletions
@@ -5,15 +5,16 @@ from src.models import Model


 class Encoder(nn.Module):
-    def __init__(self, input_size, hidden_size, latent_dim):
+    def __init__(self, data_length, channel_count, latent_dim):
         super(Encoder, self).__init__()
         self._encoder = nn.Sequential(*[
-            nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1),
-            nn.BatchNorm1d(hidden_size),
+            nn.Conv1d(1, channel_count, kernel_size=3, padding=1),  # (channel_count, L)
+            nn.BatchNorm1d(channel_count),
             nn.ReLU(),
-            nn.Conv1d(hidden_size, 2 * hidden_size, stride=2, kernel_size=3, padding=1),
-            nn.BatchNorm1d(2 * hidden_size),
-            nn.Linear(2 * hidden_size, latent_dim),
+            nn.Conv1d(channel_count, 2 * channel_count, stride=2, kernel_size=3, padding=1),  # (2 * channel_count, L / 2)
+            nn.BatchNorm1d(2 * channel_count),
+            nn.Flatten(),  # 2 * channel_count * L / 2
+            nn.Linear(2 * channel_count * data_length // 2, latent_dim),
             nn.ReLU()
         ])

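The new encoder flattens the stride-2 feature map before projecting to the latent space, so the Linear layer's input width must be 2 * channel_count * data_length // 2. A minimal standalone shape check of that arithmetic, using illustrative values (data_length=128, channel_count=4, latent_dim=32) rather than anything taken from the repository:

import torch
from torch import nn

data_length, channel_count, latent_dim = 128, 4, 32  # illustrative values

encoder = nn.Sequential(
    nn.Conv1d(1, channel_count, kernel_size=3, padding=1),                             # (B, 4, 128)
    nn.BatchNorm1d(channel_count),
    nn.ReLU(),
    nn.Conv1d(channel_count, 2 * channel_count, stride=2, kernel_size=3, padding=1),   # (B, 8, 64)
    nn.BatchNorm1d(2 * channel_count),
    nn.Flatten(),                                                                       # (B, 8 * 64) = (B, 512)
    nn.Linear(2 * channel_count * data_length // 2, latent_dim),
    nn.ReLU(),
)

x = torch.rand(16, 1, data_length)   # batch of 16 single-channel byte sequences
print(encoder(x).shape)              # torch.Size([16, 32])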
@@ -22,27 +23,28 @@ class Encoder(nn.Module):


 class Decoder(nn.Module):
-    def __init__(self, input_size, hidden_size, output_size):
+    def __init__(self, latent_dim, channel_count, data_length):
         super(Decoder, self).__init__()
-        super._decoder = nn.Sequential(*[
-            nn.Linear(input_size, 2 * hidden_size),
+        self._decoder = nn.Sequential(*[
+            nn.Linear(latent_dim, 2 * channel_count * data_length // 2),
             nn.ReLU(),
-            nn.BatchNorm1d(2 * hidden_size),
-            nn.ConvTranspose1d(2 * hidden_size, hidden_size, kernel_size=3, stride=2, padding=1, output_padding=1),
-            nn.BatchNorm1d(hidden_size),
+            nn.Unflatten(1, (2 * channel_count, data_length // 2)),
+            nn.BatchNorm1d(2 * channel_count),
+            nn.ConvTranspose1d(2 * channel_count, channel_count, kernel_size=3, stride=2, padding=1, output_padding=1),
+            nn.BatchNorm1d(channel_count),
             nn.ReLU(),
-            nn.ConvTranspose1d(hidden_size, output_size, kernel_size=3, padding=1),
+            nn.ConvTranspose1d(channel_count, 1, kernel_size=3, padding=1),
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self._decoder(x)


 class AutoEncoder(Model):
-    def __init__(self, input_size, hidden_size, latent_dim):
-        super().__init__(loss_function = nn.CrossEntropyLoss())
+    def __init__(self, input_size, channel_count, latent_dim):
+        super().__init__(loss_function = nn.MSELoss())

-        self.encoder = Encoder(input_size, hidden_size, latent_dim)
-        self.decoder = Decoder(latent_dim, hidden_size, input_size)
+        self.encoder = Encoder(input_size, channel_count, latent_dim)
+        self.decoder = Decoder(latent_dim, channel_count, input_size)

     def encode(self, x: torch.Tensor) -> torch.Tensor:
         return self.encoder(x)

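The decoder mirrors the encoder: Unflatten undoes Flatten, and the stride-2 ConvTranspose1d doubles the sequence length back to data_length. With dilation 1, PyTorch's transposed-convolution output length is (L_in - 1) * stride - 2 * padding + kernel_size + output_padding, so kernel_size=3, stride=2, padding=1, output_padding=1 gives exactly 2 * L_in. A quick check with illustrative sizes (not taken from the repository):

import torch
from torch import nn

# 2 * channel_count = 8 input channels, data_length // 2 = 64 timesteps (illustrative)
upsample = nn.ConvTranspose1d(8, 4, kernel_size=3, stride=2, padding=1, output_padding=1)
x = torch.rand(16, 8, 64)
print(upsample(x).shape)   # torch.Size([16, 4, 128]): (64 - 1) * 2 - 2 + 3 + 1 = 128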
@@ -50,5 +52,11 @@ class AutoEncoder(Model):
     def decode(self, x: torch.Tensor) -> torch.Tensor:
         return self.decoder(x)

-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.decode(self.encode(x))
+    def forward(self, x: torch.LongTensor) -> torch.Tensor:
+        x = x.float() / 255.0  # convert to floats
+        x = x.unsqueeze(1)  # add channel dimension --> (B, 1, L)
+
+        encoded = self.encoder(x)
+        decoded = self.decoder(encoded)
+
+        return decoded

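A minimal round-trip sketch of the new forward pass, assuming the classes above are importable from src.models (the exact file path is not shown in the diff) and that the Model base class accepts the loss_function keyword used above:

import torch
from src.models import AutoEncoder   # assumed import path

model = AutoEncoder(input_size=128, channel_count=4, latent_dim=32)   # illustrative sizes

batch = torch.randint(0, 256, (8, 128))   # raw byte sequences as a LongTensor (B, L)
reconstruction = model(batch)             # bytes are normalised to [0, 1] inside forward()
print(reconstruction.shape)               # (8, 1, 128): one reconstruction channel per sequence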
@@ -18,10 +18,13 @@ class CNNPredictor(Model):
         # 2. Convolutional feature extractor
         self.conv_layers = nn.Sequential(
             nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
             nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
             nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
+            nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
         )

@@ -26,7 +26,7 @@ def train(
     assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"

     if model_name:
-        print("Creating model")
+        print(f"Creating model: {model_name}")
         model = model_called[model_name]
     else:
         print("Loading model from disk")

@@ -64,6 +64,7 @@ def train(
     print("Training")
     best_model = trainer.execute(
         model=model,
+        context_length=context_length,
         train_loader=training_loader,
         validation_loader=validation_loader,
         n_epochs=n_trials,

@@ -13,6 +13,7 @@ class FullTrainer(Trainer):
     def execute(
         self,
         model: Model,
+        context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int | None,

@@ -5,30 +5,37 @@ from torch.utils.data import DataLoader
 from .train import train
 from .trainer import Trainer
-from ..models import Model, CNNPredictor, ByteTransformer
+from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder


-def create_model(trial: tr.Trial, model: nn.Module):
-    match model.__class__:
-        case CNNPredictor.__class__:
-            return model(
-                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
-                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
-                vocab_size=256,
-            )
-        case ByteTransformer.__class__:
-            nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
-            # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
-            return model(
-                d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
-                nhead=nhead,
-                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
-                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
-                dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
-                dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
-                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
-                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
-            )
+def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
+    if model_cls is CNNPredictor:
+        return CNNPredictor(
+            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
+            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+            vocab_size=256,
+        )
+    if model_cls is ByteTransformer:
+        nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
+        # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
+        return ByteTransformer(
+            d_model=context_length,
+            nhead=nhead,
+            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
+            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
+            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
+            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
+            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
+            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
+        )
+    if model_cls is AutoEncoder:
+        channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
+        latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
+        return AutoEncoder(
+            channel_count=channel_count,
+            latent_dim=latent_dim,
+            input_size=context_length,
+        )
+    return None

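create_model now dispatches on the model class itself rather than on an instance, so a bare class plus a trial is enough to build a candidate. A small smoke-test sketch using Optuna's FixedTrial, with assumed import paths (the diff does not show the file names) and purely illustrative parameter values:

import optuna

from src.models import AutoEncoder                      # assumed import path
from src.training.optuna_trainer import create_model    # assumed import path

# FixedTrial replays the given values for the suggest_* calls without running a study.
trial = optuna.trial.FixedTrial({"channel_count": 4, "latent_dim": 48})
model = create_model(trial, AutoEncoder, context_length=128)
print(type(model).__name__)   # AutoEncoder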
@@ -36,10 +43,11 @@ def objective_function(
     trial: tr.Trial,
     training_loader: DataLoader,
     validation_loader: DataLoader,
-    model: Model,
+    model: type[Model],
+    context_length: int,
     device: str
 ):
-    model = create_model(trial, model).to(device)
+    model = create_model(trial, model, context_length).to(device)
     _, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
     return min(validation_loss)

@@ -52,7 +60,8 @@ class OptunaTrainer(Trainer):

     def execute(
         self,
-        model: Model,
+        model: type[Model],
+        context_length,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int,

@@ -60,13 +69,19 @@ class OptunaTrainer(Trainer):
     ) -> nn.Module:
         study = optuna.create_study(direction="minimize")
         study.optimize(
-            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
+            lambda trial: objective_function(trial, train_loader, validation_loader, model, context_length, device),
             n_trials=self.n_trials
         )

         best_params = study.best_trial.params
-        best_model = model(
-            **best_params
-        )
+        if model is AutoEncoder:
+            best_model = AutoEncoder(
+                input_size=context_length,
+                **best_params
+            )
+        elif model is CNNPredictor:
+            best_model = CNNPredictor(**best_params)
+        else:
+            raise ValueError(f"Unknown model type: {model}")

         return best_model

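Only channel_count and latent_dim are tuned for the AutoEncoder, so study.best_trial.params does not contain input_size; it has to be supplied again from context_length when the best model is rebuilt, which is why the AutoEncoder branch above differs from the CNNPredictor one. An illustrative sketch (parameter values invented, import path assumed):

from src.models import AutoEncoder   # assumed import path

# Shape of study.best_trial.params after an AutoEncoder search; values are invented.
best_params = {"channel_count": 4, "latent_dim": 48}

# input_size is not an Optuna parameter, so it is re-supplied from context_length.
best_model = AutoEncoder(input_size=128, **best_params)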
@@ -4,7 +4,7 @@ import torch
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm

-from ..models import ByteTransformer, Model
+from ..models import ByteTransformer, Model, AutoEncoder


 def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:

@@ -53,9 +53,12 @@ def train(
             y = y.long().to(device)

             optimizer.zero_grad()
-            logits = _forward(model, x, device)
+            pred = _forward(model, x, device)

-            loss = loss_fn(logits, y)
+            if isinstance(model, AutoEncoder):
+                loss = loss_fn(pred, x.float() / 255.0)
+            else:
+                loss = loss_fn(pred, y)
             loss.backward()
             optimizer.step()

@@ -71,8 +74,12 @@ def train(
             x = x.long().to(device)
             y = y.long().to(device)

-            logits = _forward(model, x, device)
-            loss = loss_fn(logits, y)
+            pred = _forward(model, x, device)
+
+            if isinstance(model, AutoEncoder):
+                loss = loss_fn(pred, x.float() / 255.0)
+            else:
+                loss = loss_fn(pred, y)
             losses.append(loss.item())

         avg_loss = sum(losses) / len(losses)

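Both loops now pick the loss target by model type: the predictor models are still scored against their byte labels with cross-entropy, while the AutoEncoder is scored against its own normalised input with mean-squared error. A standalone sketch of the two targets with illustrative shapes (the reconstruction is assumed to carry the decoder's single channel dimension):

import torch
from torch import nn

x = torch.randint(0, 256, (4, 128))   # raw byte batch (B, L) from the loader
y = torch.randint(0, 256, (4,))       # illustrative class labels for the predictor path

# Predictor path: cross-entropy over the 256-byte vocabulary.
logits = torch.randn(4, 256)
ce_loss = nn.CrossEntropyLoss()(logits, y)

# AutoEncoder path: mean-squared error against the normalised input bytes.
reconstruction = torch.rand(4, 1, 128)        # decoder output (B, 1, L)
target = (x.float() / 255.0).unsqueeze(1)     # match the (B, 1, L) layout
mse_loss = nn.MSELoss()(reconstruction, target)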
@@ -10,7 +10,8 @@ class Trainer(ABC):
     @abstractmethod
     def execute(
         self,
-        model: nn.Module | None,
+        model: nn.Module | type[nn.Module] | None,
+        context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int | None,