feat: uhm, i changed some things
This commit is contained in: parent b58682cb49 / commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions
CNN-model/datasets/LoremIpsumDataset.py (Normal file, +5)
@@ -0,0 +1,5 @@
import lorem

class LoremIpsumDataset:
    def __init__(self):
        self.data = lorem.text(paragraphs=100)
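As written, the class stores a plain string, while the entry script later in this commit calls len(dataset) and reads dataset.data["text"]; neither works on this class as committed. A minimal sketch of the protocol it would need (an assumption about the intended use, not part of the commit):

    import lorem

    class LoremIpsumDataset:
        def __init__(self):
            # dict form matches the dataset.data["text"] access in the entry script
            self.data = {"text": lorem.text(paragraphs=100)}

        def __len__(self):
            # length in characters, so the 80/20 split in the entry script has something to split
            return len(self.data["text"])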
CNN-model/datasets/__init__.py (Normal file, +2)
@@ -0,0 +1,2 @@
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
@@ -1,93 +1,55 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import optuna.trial as tr
from torch.utils.data import DataLoader
from tqdm import tqdm
import argparse
from argparse import ArgumentParser
from math import ceil

from optuna_trial import create_model
from utils import make_context_pairs, load_data
import optuna
import torch
from torch.utils.data import DataLoader, TensorDataset

from datasets import EnWik9DataSet, LoremIpsumDataset
from trainers import FullTrainer, OptunaTrainer, Trainer

BATCH_SIZE = 64
DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"

# hyper parameters
context_length = 128

def train_and_eval(
    model: nn.Module,
    training_data: bytes,
    validation_data: bytes,
    batch_size: int,
    epochs: int = 100,
    learning_rate: float = 1e-3,
    device: torch.device = torch.device("cpu")
) -> dict:
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    training_loader = DataLoader(make_context_pairs(training_data, context_length=context_length), batch_size=batch_size)
    validation_loader = DataLoader(make_context_pairs(validation_data, context_length=context_length), batch_size=batch_size)

    training_losses = []
    validation_losses = []
    best_val_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
            x, y = x.to(device), y.to(device)
            prediction = model(x)
            loss = F.cross_entropy(prediction, y)
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        training_losses.append(train_loss / len(training_loader))

        model.eval()

        with torch.no_grad():
            val_loss = 0
            for x, y in validation_loader:
                x, y = x.to(device), y.to(device)
                prediction = model(x)
                loss = F.cross_entropy(prediction, y)
                val_loss += loss.item()
            validation_losses.append(val_loss / len(validation_loader))
            if validation_losses[-1] < best_val_loss:
                best_val_loss = validation_losses[-1]

    return {
        "training_losses": training_losses,
        "validation_losses": validation_losses,
        "best_validation_loss": best_val_loss
    }


def objective_function(trial: tr.Trial, train_data: bytes, validation_data: bytes, batch_size: int):
    model = create_model(trial)
    result = train_and_eval(model, train_data, validation_data, batch_size)
    return result["best_validation_loss"]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--train-data", type=str, required=True)
    parser.add_argument("--validation-data", type=str, required=True)
    parser.add_argument("--batch-size", type=int, default=128)

    print(f"Running on device: {DEVICE}...")
    parser = ArgumentParser()
    parser.add_argument("--method", choices=["optuna", "train"], required=True)
    parser.add_argument("--model-path", type=str, required=False)
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_data = load_data(args.train_data)
    validation_data = load_data(args.validation_data)
    batch_size = args.batch_size
    if args.method == "train":
        dataset = EnWik9DataSet()
    elif args.method == "optuna":
        dataset = LoremIpsumDataset()
    else:
        raise ValueError(f"Unknown method: {args.method}")

    print(f"training data length: {len(train_data)}")
    print(f"validation data length: {len(validation_data)}")
    print(f"batch size: {batch_size}")
    dataset_length = len(dataset)
    training_size = ceil(0.8 * dataset_length)

    study = optuna.create_study(study_name="CNN network", direction="minimize")
    study.optimize(lambda trial: objective_function(trial, train_data, validation_data, batch_size), n_trials=10)
    print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
    data = dataset.data["text"]

    train_set, validate_set = torch.utils.data.random_split(TensorDataset(data),
                                                            [training_size, dataset_length - training_size])
    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
    loss_fn = torch.nn.CrossEntropyLoss()

    model = None
    if args.model_path is not None:
        model = torch.load(args.model_path)

    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()  # FullTrainer (added below) handles --method train

    trainer.execute(
        model=model,
        train_loader=training_loader,
        validation_loader=validation_loader,
        loss_fn=loss_fn,
        n_epochs=200,
        device=DEVICE
    )
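Assuming this file is the package's command-line entry point (its filename is not shown on this page), the two modes are invoked along the lines of: python <script>.py --method optuna for the hyper-parameter search, and python <script>.py --method train --model-path models/final_model.pt for a full training run.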
CNN-model/model/__init__.py (Normal file, +1)
@@ -0,0 +1 @@
from .cnn import CNNPredictor
@@ -8,7 +8,7 @@ class CausalConv1d(nn.Conv1d):
    def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
        super().__init__(input_channels, output_channels, kernel_size, padding=kernel_size-1, **kwargs)
    def forward(self, input: Tensor) -> Tensor:
-        return super().forward(input)
+        return super().forward(input)[:, :, :input.size(-1)]  # trim the padded tail so position t never sees inputs after t

class CNNPredictor(nn.Module):
    def __init__(
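Why the added slice makes the convolution causal: with padding=kernel_size-1 on both sides, nn.Conv1d emits kernel_size-1 extra positions on the right that depend on future inputs, and trimming back to the input length drops exactly those. A standalone shape check (illustrative, not part of the diff):

    import torch
    import torch.nn as nn

    conv = nn.Conv1d(1, 1, kernel_size=3, padding=2)  # pad k-1 = 2 on both sides
    x = torch.randn(1, 1, 10)                          # (batch, channels, length)
    y = conv(x)                                        # length 10 + (k-1) = 12
    y_causal = y[:, :, :x.size(-1)]                    # back to length 10, aligned with the input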
@@ -41,5 +41,5 @@ class CNNPredictor(nn.Module):
        embedding = embedding.transpose(1, 2)  # B, H, L
        prediction = self.network(embedding)
        last_prediction = prediction[:, :, -1]
-        return softmax(self.output_layer(last_prediction), dim=-1)  # convert output of linear layer to prob. distr.
+        return self.output_layer(last_prediction)  # return raw logits; the loss applies log-softmax itself
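Dropping the softmax matters: torch.nn.CrossEntropyLoss and F.cross_entropy expect raw logits and apply log-softmax internally, so feeding them already-softmaxed probabilities silently computes a flatter, mis-scaled loss. A standalone illustration:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 256)                # batch of 4 over a 256-byte vocabulary
    targets = torch.randint(0, 256, (4,))

    loss_ok = F.cross_entropy(logits, targets)                   # intended usage
    loss_bad = F.cross_entropy(logits.softmax(dim=-1), targets)  # double softmax, distorted gradients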
@@ -1,18 +0,0 @@
import optuna.trial as tr
from cnn import CNNPredictor

def create_model(trial: tr.Trial, vocab_size: int = 256, context_length: int = 128):
    num_layers = trial.suggest_int("num_layers", 1, 6)
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    kernel_size = trial.suggest_int("kernel_size", 2, 7)
    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])

    return CNNPredictor(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob,
        use_batchnorm=use_batchnorm
    )
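This deleted create_model helper reappears nearly verbatim in CNN-model/trainers/OptunaTrainer.py below (minus the unused context_length parameter), so the deletion is a move rather than a removal.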
CNN-model/trainers/FullTrainer.py (Normal file, +26)
@@ -0,0 +1,26 @@
from typing import Callable

import torch
from torch import nn as nn
from torch.utils.data import DataLoader

from .trainer import Trainer
from .train import train
from ..utils import print_losses

class FullTrainer(Trainer):
    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")

        model.to(device)
        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
        print_losses(train_loss, val_loss)
CNN-model/trainers/OptunaTrainer.py (Normal file, +63)
@@ -0,0 +1,63 @@
from typing import Callable

import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader

from .trainer import Trainer
from ..model.cnn import CNNPredictor
from .train import train


def create_model(trial: tr.Trial, vocab_size: int = 256):
    num_layers = trial.suggest_int("num_layers", 1, 6)
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    kernel_size = trial.suggest_int("kernel_size", 2, 7)
    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])

    return CNNPredictor(
        vocab_size=vocab_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        kernel_size=kernel_size,
        dropout_prob=dropout_prob,
        use_batchnorm=use_batchnorm
    )


def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str
):
    model = create_model(trial).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
    return min(validation_loss)


class OptunaTrainer(Trainer):
    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
            n_trials=20
        )

        best_params = study.best_trial.params
        best_model = CNNPredictor(
            **best_params
        )
        torch.save(best_model, "models/final_model.pt")
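Note that CNNPredictor(**best_params) re-instantiates the network from the winning hyper-parameters, so the saved file holds fresh, untrained weights, and vocab_size is not in study.best_trial.params (only suggest_*'d values are). If the saved model is meant to carry trained weights, one option is to re-fit the winning configuration before saving (a sketch reusing the objects in scope in execute, not code from the commit):

    best_model = CNNPredictor(vocab_size=256, **study.best_trial.params).to(device)
    train(best_model, train_loader, validation_loader, loss_fn, n_epochs)  # re-fit the best configuration
    torch.save(best_model, "models/final_model.pt")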
CNN-model/trainers/__init__.py (Normal file, +3)
@@ -0,0 +1,3 @@
from .FullTrainer import FullTrainer
from .OptunaTrainer import OptunaTrainer
from .trainer import Trainer
CNN-model/trainers/train.py (Normal file, +50)
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm

from typing import Callable


def train(
    model: nn.Module,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8
) -> tuple[list[float], list[float]]:
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    avg_training_losses = []
    avg_validation_losses = []

    for epoch in range(epochs):
        model.train()
        total_loss = []

        for x, y in tqdm(training_loader):  # each batch unpacks into an (input, target) pair
            optimizer.zero_grad()

            prediction = model(x)

            loss = loss_fn(prediction, y)
            loss.backward()
            optimizer.step()

            total_loss.append(loss.item())

        avg_training_losses.append(sum(total_loss) / len(total_loss))

        model.eval()  # disable dropout/batchnorm updates during validation
        with torch.no_grad():
            losses = []
            for x, y in validation_loader:
                prediction = model(x)
                loss = loss_fn(prediction, y)
                losses.append(loss.item())
            avg_loss = sum(losses) / len(losses)
            avg_validation_losses.append(avg_loss)
            tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")

    return avg_training_losses, avg_validation_losses
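train() assumes every batch unpacks into an (input, target) pair, which a two-tensor TensorDataset yields. A minimal compatible dataset sketch (shapes are illustrative assumptions):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    xs = torch.randint(0, 256, (1000, 128))  # 1000 windows of 128 context bytes
    ys = torch.randint(0, 256, (1000,))      # the byte that follows each window
    loader = DataLoader(TensorDataset(xs, ys), batch_size=64, shuffle=True)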
CNN-model/trainers/trainer.py (Normal file, +22)
@@ -0,0 +1,22 @@
from abc import ABC, abstractmethod
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


class Trainer(ABC):
    """Abstract class for trainers."""

    @abstractmethod
    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        pass
CNN-model/utils/__init__.py (Normal file, +1)
@@ -0,0 +1 @@
from .utils import *
@@ -14,6 +14,17 @@ def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
    plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
    plt.show()

def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
    plt.plot(train_losses, label="Training loss")
    plt.plot(validation_losses, label="Validation loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss (cross entropy)")
    plt.legend()

    if show:
        plt.show()
    plt.savefig("losses.png")  # note: with some backends, saving after plt.show() writes an empty figure


def load_data(path: str) -> bytes:
    with open(path, "rb") as f: