feat: uhm, i changed some things
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions
CNN-model/datasets/LoremIpsumDataset.py (new file, +5)
@@ -0,0 +1,5 @@
+import lorem
+
+class LoremIpsumDataset:
+    def __init__(self):
+        self.data = lorem.text(paragraphs=100)
CNN-model/datasets/__init__.py (new file, +2)
@@ -0,0 +1,2 @@
+from EnWik9 import EnWik9DataSet
+from LoremIpsumDataset import LoremIpsumDataset
@@ -1,93 +1,55 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import optuna.trial as tr
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-import argparse
-
-from optuna_trial import create_model
-from utils import make_context_pairs, load_data
-import optuna
+from argparse import ArgumentParser
+from math import ceil
+
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+from datasets import EnWik9DataSet, LoremIpsumDataset
+from trainers import OptunaTrainer, Trainer
+
+BATCH_SIZE = 64
+DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
 
 # hyper parameters
 context_length = 128
 
-
-def train_and_eval(
-    model: nn.Module,
-    training_data: bytes,
-    validation_data: bytes,
-    batch_size: int,
-    epochs: int = 100,
-    learning_rate: float = 1e-3,
-    device: torch.device = torch.device("cpu")
-) -> dict:
-    model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-    training_loader = DataLoader(make_context_pairs(training_data, context_length=context_length), batch_size=batch_size)
-    validation_loader = DataLoader(make_context_pairs(validation_data, context_length=context_length), batch_size=batch_size)
-
-    training_losses = []
-    validation_losses = []
-    best_val_loss = float("inf")
-
-    for epoch in range(epochs):
-        model.train()
-        train_loss = 0
-        for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
-            x, y = x.to(device), y.to(device)
-            prediction = model(x)
-            loss = F.cross_entropy(prediction, y)
-            train_loss += loss.item()
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-        training_losses.append(train_loss / len(training_loader))
-
-        model.eval()
-
-        with torch.no_grad():
-            val_loss = 0
-            for x, y in validation_loader:
-                x, y = x.to(device), y.to(device)
-                prediction = model(x)
-                loss = F.cross_entropy(prediction, y)
-                val_loss += loss.item()
-            validation_losses.append(val_loss / len(validation_loader))
-            if validation_losses[-1] < best_val_loss:
-                best_val_loss = validation_losses[-1]
-
-    return {
-        "training_losses": training_losses,
-        "validation_losses": validation_losses,
-        "best_validation_loss": best_val_loss
-    }
-
-
-def objective_function(trial: tr.Trial, train_data: bytes, validation_data: bytes, batch_size: int):
-    model = create_model(trial)
-    result = train_and_eval(model, train_data, validation_data, batch_size)
-    return result["best_validation_loss"]
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train-data", type=str, required=True)
-    parser.add_argument("--validation-data", type=str, required=True)
-    parser.add_argument("--batch-size", type=int, default=128)
+    print(f"Running on device: {DEVICE}...")
+    parser = ArgumentParser()
+    parser.add_argument("--method", choices=["optuna", "train"], required=True)
+    parser.add_argument("--model-path", type=str, required=False)
 
     args = parser.parse_args()
-    print(args)
 
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    train_data = load_data(args.train_data)
-    validation_data = load_data(args.validation_data)
-    batch_size = args.batch_size
+    if args.method == "train":
+        dataset = EnWik9DataSet()
+    elif args.method == "optuna":
+        dataset = LoremIpsumDataset()
+    else:
+        raise ValueError(f"Unknown method: {args.method}")
 
-    print(f"training data length: {len(train_data)}")
-    print(f"validation data length: {len(validation_data)}")
-    print(f"batch size: {batch_size}")
+    dataset_length = len(dataset)
+    training_size = ceil(0.8 * dataset_length)
 
-    study = optuna.create_study(study_name="CNN network", direction="minimize")
-    study.optimize(lambda trial: objective_function(trial, train_data, validation_data, batch_size), n_trials=10)
+    print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
+    data = dataset.data["text"]
+
+    train_set, validate_set = torch.utils.data.random_split(TensorDataset(data),
+                                                            [training_size, dataset_length - training_size])
+    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
+    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
+    loss_fn = torch.nn.CrossEntropyLoss()
+
+    model = None
+    if args.model_path is not None:
+        model = torch.load(args.model_path)
+
+    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else None
+    trainer.execute(
+        model=model,
+        train_loader=training_loader,
+        validation_loader=validation_loader,
+        loss_fn=loss_fn,
+        n_epochs=200,
+        device=DEVICE
+    )
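With this hunk the entry point no longer takes --train-data/--validation-data paths; --method now both selects the dataset (EnWik9 vs. lorem ipsum) and decides which Trainer runs. A hedged usage sketch (the script's file name is not shown in this diff, main.py is an assumption; note that only --method optuna has a trainer wired up, since the train branch leaves trainer as None):

    python main.py --method optuna
    python main.py --method train --model-path models/final_model.pt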
CNN-model/model/__init__.py (new file, +1)
@@ -0,0 +1 @@
+from .cnn import CNNPredictor
@@ -8,7 +8,7 @@ class CausalConv1d(nn.Conv1d):
     def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
         super().__init__(input_channels, output_channels, kernel_size, padding=kernel_size-1, **kwargs)
     def forward(self, input: Tensor) -> Tensor:
-        return super().forward(input)
+        return super().forward(input)[:, :, :input.size(-1)]
 
 class CNNPredictor(nn.Module):
     def __init__(
@@ -41,5 +41,5 @@ class CNNPredictor(nn.Module):
         emdedding = emdedding.transpose(1, 2)  # B, H, L
         prediction = self.network(emdedding)
         last_prediction = prediction[:, :, -1]
-        return softmax(self.output_layer(last_prediction), dim=-1)  # convert output of linear layer to prob. distr.
+        return self.output_layer(last_prediction)
 
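Two behavioural fixes land in this file. First, nn.Conv1d with padding=kernel_size-1 pads both ends, so the raw output is kernel_size-1 positions longer than the input; slicing [:, :, :input.size(-1)] keeps only the first L outputs, each of which depends on current and past positions only. Second, the head now returns raw logits: F.cross_entropy and nn.CrossEntropyLoss apply log-softmax internally, so the removed softmax would have normalized twice. A minimal, self-contained sketch of the truncation (shapes are illustrative):

    import torch
    import torch.nn as nn
    from torch import Tensor

    class CausalConv1d(nn.Conv1d):
        def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
            # symmetric padding of kernel_size-1 makes the raw output too long
            super().__init__(input_channels, output_channels, kernel_size,
                             padding=kernel_size - 1, **kwargs)

        def forward(self, input: Tensor) -> Tensor:
            # keep the first L positions: output t sees inputs <= t only
            return super().forward(input)[:, :, :input.size(-1)]

    x = torch.randn(2, 4, 16)                 # batch, channels, length
    y = CausalConv1d(4, 8, kernel_size=3)(x)
    assert y.shape == (2, 8, 16)              # length preserved after truncation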
@@ -1,18 +0,0 @@
-import optuna.trial as tr
-from cnn import CNNPredictor
-
-def create_model(trial: tr.Trial, vocab_size: int = 256, context_length: int = 128):
-    num_layers = trial.suggest_int("num_layers", 1, 6)
-    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    kernel_size = trial.suggest_int("kernel_size", 2, 7)
-    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
-    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
-
-    return CNNPredictor(
-        vocab_size=vocab_size,
-        num_layers=num_layers,
-        hidden_dim=hidden_dim,
-        kernel_size=kernel_size,
-        dropout_prob=dropout_prob,
-        use_batchnorm=use_batchnorm
-    )
CNN-model/trainers/FullTrainer.py (new file, +26)
@@ -0,0 +1,26 @@
+from typing import Callable
+
+import torch
+from torch import nn as nn
+from torch.utils.data import DataLoader
+
+from trainer import Trainer
+from train import train
+from ..utils import print_losses
+
+class FullTrainer(Trainer):
+    def execute(
+        self,
+        model: nn.Module | None,
+        train_loader: DataLoader,
+        validation_loader: DataLoader,
+        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        n_epochs: int,
+        device: str
+    ) -> None:
+        if model is None:
+            raise ValueError("Model must be provided: run optuna optimizations first")
+
+        model.to(device)
+        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
+        print_losses(train_loss, val_loss)
CNN-model/trainers/OptunaTrainer.py (new file, +63)
@@ -0,0 +1,63 @@
+from typing import Callable
+
+import optuna
+import optuna.trial as tr
+import torch
+from torch import nn as nn
+from torch.utils.data import DataLoader
+
+from trainer import Trainer
+from ..model.cnn import CNNPredictor
+from train import train
+
+
+def create_model(trial: tr.Trial, vocab_size: int = 256):
+    num_layers = trial.suggest_int("num_layers", 1, 6)
+    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
+    kernel_size = trial.suggest_int("kernel_size", 2, 7)
+    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
+    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
+
+    return CNNPredictor(
+        vocab_size=vocab_size,
+        num_layers=num_layers,
+        hidden_dim=hidden_dim,
+        kernel_size=kernel_size,
+        dropout_prob=dropout_prob,
+        use_batchnorm=use_batchnorm
+    )
+
+
+def objective_function(
+    trial: tr.Trial,
+    training_loader: DataLoader,
+    validation_loader: DataLoader,
+    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    device: str
+):
+    model = create_model(trial).to(device)
+    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
+    return min(validation_loss)
+
+
+class OptunaTrainer(Trainer):
+    def execute(
+        self,
+        model: nn.Module | None,
+        train_loader: DataLoader,
+        validation_loader: DataLoader,
+        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        n_epochs: int,
+        device: str
+    ) -> None:
+        study = optuna.create_study(study_name="CNN network", direction="minimize")
+        study.optimize(
+            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
+            n_trials=20
+        )
+
+        best_params = study.best_trial.params
+        best_model = CNNPredictor(
+            **best_params
+        )
+        torch.save(best_model, "models/final_model.pt")
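For reference, the study/objective pattern OptunaTrainer uses, reduced to a self-contained sketch (the quadratic objective is a stand-in for create_model plus train):

    import optuna

    def objective(trial: optuna.trial.Trial) -> float:
        x = trial.suggest_float("x", -10.0, 10.0)  # stand-in for the hyperparameter suggests
        return (x - 2.0) ** 2                      # stand-in for the best validation loss

    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)
    print(study.best_trial.params)                 # e.g. {'x': 2.01...}

study.best_trial.params holds exactly the suggested names, which is why CNNPredictor(**best_params) above works: the suggest keys match the constructor's keyword arguments (vocab_size falls back to its default). Note that the saved best_model is constructed from those parameters but not retrained before torch.save.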
CNN-model/trainers/__init__.py (new file, +2)
@@ -0,0 +1,2 @@
+from OptunaTrainer import OptunaTrainer
+from trainer import Trainer
CNN-model/trainers/train.py (new file, +50)
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+from typing import Callable
+
+
+def train(
+    model: nn.Module,
+    training_loader: DataLoader,
+    validation_loader: DataLoader,
+    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    epochs: int = 100,
+    learning_rate: float = 1e-3,
+    weight_decay: float = 1e-8
+) -> tuple[list[float], list[float]]:
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+    avg_training_losses = []
+    avg_validation_losses = []
+
+    for epoch in range(epochs):
+        model.train()
+        total_loss = []
+
+        for data in tqdm(training_loader):
+            optimizer.zero_grad()
+
+            x_hat = model(data)
+
+            loss = loss_fn(x_hat, data)
+            loss.backward()
+            optimizer.step()
+
+            total_loss.append(loss.item())
+
+        avg_training_losses.append(sum(total_loss) / len(total_loss))
+
+        with torch.no_grad():
+            losses = []
+            for data in validation_loader:
+                x_hat = model(data)
+                loss = loss_fn(x_hat, data)
+                losses.append(loss.item())
+            avg_loss = sum(losses) / len(losses)
+            avg_validation_losses.append(avg_loss)
+            tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
+
+    return avg_training_losses, avg_validation_losses
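As written, train() treats each batch data as both the model input and the loss target, so it assumes the DataLoader yields a single tensor per batch (a reconstruction-style setup). A minimal sketch of that contract, with a hypothetical stand-in model:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader

    model = nn.Linear(8, 8)                                # hypothetical stand-in model
    loader = DataLoader(torch.randn(32, 8), batch_size=4)  # yields one tensor per batch
    loss_fn = nn.MSELoss()
    # train(model, loader, loader, loss_fn, epochs=1)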
CNN-model/trainers/trainer.py (new file, +22)
@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+
+class Trainer(ABC):
+    """Abstract class for trainers."""
+
+    @abstractmethod
+    def execute(
+        self,
+        model: nn.Module | None,
+        train_loader: DataLoader,
+        validation_loader: DataLoader,
+        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        n_epochs: int,
+        device: str
+    ) -> None:
+        pass
CNN-model/utils/__init__.py (new file, +1)
@@ -0,0 +1 @@
+from .utils import *
@@ -14,6 +14,17 @@ def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
     plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
     plt.show()
 
+
+def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
+    plt.plot(train_losses, label="Training loss")
+    plt.plot(validation_losses, label="Validation loss")
+    plt.xlabel("Epoch")
+    plt.ylabel("Loss (cross entropy)")
+    plt.legend()
+
+    if show:
+        plt.show()
+    plt.savefig("losses.png")
 
 def load_data(path: str) -> bytes:
     with open(path, "rb") as f:
@@ -36,7 +36,7 @@ def CDF_fn(pz, bin_width, variable_type, distribution_type):
 
     bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
    bin_locations = bin_locations.float() * bin_width
     bin_locations = bin_locations.to(device=pz[0].device)
 
     pz = [param[:, :, :, :, None] for param in pz]
     cdf = cdf_fn(
@@ -86,7 +86,7 @@ def decode_sample(
         state, pz, variable_type, distribution_type, bin_width=1./256):
     state = rans.unflatten(state)
 
     device = pz[0].device
     size = pz[0].size()[0:4]
 
     CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)
@@ -190,7 +190,7 @@ def run(args, kwargs):
     import models.Model as Model
 
     model = Model.Model(args)
     args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     model.set_temperature(args.temperature)
     model.enable_hard_round(args.hard_round)
 
@@ -208,7 +208,7 @@ def run(args, kwargs):
         print("Let's use", torch.cuda.device_count(), "GPUs!")
         model = torch.nn.DataParallel(model, dim=0)
 
     model.to(args.device)
 
     def lr_lambda(epoch):
         return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)
@@ -34,7 +34,7 @@ def _stacked_sigmoid(x, temperature, n_approx=3):
     x_remainder = x_remainder.view(size + (1,))
 
     translation = torch.arange(n_approx) - n_approx // 2
     translation = translation.to(device=x.device, dtype=x.dtype)
     translation = translation.view([1] * len(size) + [len(translation)])
     out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)
 
@@ -17,7 +17,7 @@ def train(epoch, train_loader, model, opt, args):
     for batch_idx, (data, _) in enumerate(train_loader):
         data = data.view(-1, *args.input_size)
 
         data = data.to(args.device)
 
         opt.zero_grad()
         loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)
@@ -5,8 +5,12 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
+    "datasets>=4.4.1",
+    "lorem>=0.1.1",
     "matplotlib>=3.10.7",
     "numpy>=2.3.4",
     "optuna>=4.5.0",
     "torch>=2.9.0",
+    "torchdata>=0.11.0",
+    "torchvision>=0.24.0",
 ]
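The four added dependencies sit in standard PEP 621 metadata, so any compatible installer can resolve them; a minimal sketch, assuming an editable development install is wanted:

    pip install -e .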
@@ -176,9 +176,9 @@ class RelMultiHeadAttn(nn.Module):
     def _shift(self, x, qlen, klen, mask, left=False):
         if qlen > 1:
             zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                    device=x.device, dtype=x.dtype)
         else:
             zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
 
         if left:
             mask = mask.flip(1)
@@ -193,7 +193,7 @@ class RelMultiHeadAttn(nn.Module):
 
     def _rel_shift(self, x, zero_triu=False):
         zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                                device=x.device, dtype=x.dtype)
         x_padded = torch.cat([zero_pad, x], dim=1)
 
         x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
@@ -661,7 +661,7 @@ class MemTransformerLM(nn.Module):
 
         hids = []
         if self.attn_type == 0:  # default
-            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
+            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
             if self.clamp_len > 0:
                 pos_seq.clamp_(max=self.clamp_len)
@@ -691,7 +691,7 @@ class MemTransformerLM(nn.Module):
                                        r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
         elif self.attn_type == 2:  # absolute
             pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                    dtype=word_emb.dtype)
             if self.clamp_len > 0:
                 pos_seq.clamp_(max=self.clamp_len)
@@ -160,7 +160,7 @@ np.random.seed(args.seed)
 torch.manual_seed(args.seed)
 if torch.cuda.is_available():
     if not args.cuda:
         print('WARNING: You have a CUDA device, so you should probably run with --cuda')
     else:
         torch.cuda.manual_seed_all(args.seed)
 
@@ -50,7 +50,7 @@ class AdaptiveLogSoftmax(nn.Module):
         head_logprob = F.log_softmax(head_logit, dim=1)
 
         nll = torch.zeros_like(target,
                                dtype=hidden.dtype, device=hidden.device)
 
         offset = 0
         cutoff_values = [0] + self.cutoffs
@@ -38,7 +38,7 @@ class LogUniformSampler(object):
 
         with torch.no_grad():
             neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
             device = labels.device
             neg_samples = neg_samples.to(device)
             true_log_probs = self.log_q[labels].to(device)
             samp_log_probs = self.log_q[neg_samples].to(device)
@@ -112,7 +112,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
         head_logprob = F.log_softmax(head_logit, dim=1)
 
         nll = torch.zeros_like(target,
                                dtype=hidden.dtype, device=hidden.device)
 
         offset = 0
         cutoff_values = [0] + self.cutoffs
@@ -1,7 +1,7 @@
 import os
 import tensorflow as tf
 
 def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
   def _assign(op):
     node_def = op if isinstance(op, tf.NodeDef) else op.node_def
     if node_def.op == "Variable":
@@ -724,7 +724,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
 
   hooks = []
 
   with ops.device(device):
     user_context = tpu_context.TPUContext(
         internal_ctx=ctx,
         input_device=device,
@@ -758,7 +758,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
   Returns:
     list of dict of ops.
   """
   with ops.device(device):
     num_of_replicas_per_host = ctx.num_of_replicas_per_host
     # Convert user input to features and labels. If the user returns a
     # dataset, it is initialized and the features and labels extracted via
@@ -799,7 +799,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
   captured_infeed_queue = _CapturedObject()
   hooks = []
 
   with ops.device(device):
     user_context = tpu_context.TPUContext(
         internal_ctx=ctx,
         input_device=device,
@@ -827,7 +827,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
     per_host_sharded_inputs = []
     num_replicas_per_host = ctx.num_of_replicas_per_host
     cached_signals = None
     with ops.device(device):
       if not inputs.is_dataset:
         raise TypeError('`input_fn` must return a `Dataset` for this mode.')
       for _ in range(num_replicas_per_host):
@@ -888,7 +888,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
   captured_infeed_queue = _CapturedObject()
   hooks = []
   device_0 = ctx.tpu_host_placement_function(host_id=0)
   with ops.device(device_0):
     user_context = tpu_context.TPUContext(
         internal_ctx=ctx, input_device=device_0, invocation_index=0)
     inputs = _Inputs.from_input_fn(input_fn(user_context))
@@ -924,7 +924,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
   flattened_inputs = None  # Cache result from input_fn.
   signals = None
   for host_id in xrange(num_hosts):
     with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
       for _ in xrange(ctx.num_of_replicas_per_host):
         # Note: input_fn is only called once at host 0 for the first replica.
         # The features and labels returned from that invocation are
@@ -1147,7 +1147,7 @@ class _InputPipeline(object):
 
     def dequeue_fn():
       """dequeue_fn is used by TPU to retrieve the tensors."""
       # In the model-parallel case, both the host-side and device-side
       # computations must agree on the core on which infeed takes place. We
       # choose to perform infeed on logical core 0 of each replica.
       values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
@@ -1173,7 +1173,7 @@ class _InputPipeline(object):
       # host.
       for host_id in range(num_hosts):
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
             enqueue_ops_fn, captured_infeed_queue = (
                 generate_per_core_enqueue_ops_fn_for_host(
@@ -1211,7 +1211,7 @@ class _InputPipeline(object):
     else:
       for host_id in range(num_hosts):
         host_device = tpu_host_placement_fn(host_id=host_id)
         with ops.device(host_device):
           with ops.name_scope('input_pipeline_task%d' % (host_id)):
             if self._ctx.is_input_per_host_with_iterators():
               enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
@@ -1712,7 +1712,7 @@ class _OutfeedHostCall(object):
     for name in self._names:
       tensors.extend(self._tensors[name])
 
     with ops.device(tpu.core(0)):
       return [tpu_ops.outfeed_enqueue_tuple(tensors)]
 
   def create_tpu_hostcall(self):
@@ -1751,7 +1751,7 @@ class _OutfeedHostCall(object):
     # per replica.
     for i in xrange(self._ctx.num_replicas):
       host_device, ordinal_id = self._ctx.device_for_replica(i)
       with ops.device(host_device):
         outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
             dtypes=tensor_dtypes,
             shapes=tensor_shapes,
@@ -1770,7 +1770,7 @@ class _OutfeedHostCall(object):
     # place all ops on tpu host if possible.
     #
     # TODO(jhseu): Evaluate whether this is right for summaries.
     with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
       for name in self._names:
         dequeue_ops = dequeue_ops_by_name[name]
         for i, item in enumerate(dequeue_ops):
@@ -2426,7 +2426,7 @@ class TPUEstimator(estimator_lib.Estimator):
       # For export_savedmodel, input_fn is never passed to Estimator. So,
       # `is_export_mode` must be False.
       if ctx.is_running_on_cpu(is_export_mode=False):
         with ops.device('/device:CPU:0'):
           return input_fn(**kwargs)
 
       # For TPU computation, input_fn should be invoked in a tf.while_loop for
@@ -2971,7 +2971,7 @@ def _wrap_computation_in_while_loop(device, op_fn):
   iterations_per_loop_var = _create_or_get_iterations_per_loop()
   # By setting parallel_iterations=1, the parallel execution in while_loop is
   # basically turned off.
   with ops.device(device):
     iterations = array_ops.identity(iterations_per_loop_var)
     return control_flow_ops.while_loop(
         lambda i: i < iterations,
@@ -2995,7 +2995,7 @@ def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
 
   # By setting parallel_iterations=1, the parallel execution in while_loop is
   # basically turned off.
   with ops.device(device):
     return control_flow_ops.while_loop(
         cond,
         computation, [_StopSignals.NON_STOPPING_SIGNAL],
@@ -3006,7 +3006,7 @@ def _validate_tpu_training_graph():
   """Validate graph before running distributed training.
 
   Raises:
     ValueError: If the graph seems invalid for running on device
   """
   operations = ops.get_default_graph().get_operations()
 
@@ -249,7 +249,7 @@ def train(n_token, cutoffs, ps_device):
 
   for i in range(FLAGS.num_core_per_host):
     reuse = True if i > 0 else None
     with tf.device(assign_to_gpu(i, ps_device)), \
         tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
 
       mems_i = [tf.placeholder(tf.float32,
@@ -384,7 +384,7 @@ def evaluate(n_token, cutoffs, ps_device):
   tower_mems, tower_losses, tower_new_mems = [], [], []
 
   for i in range(FLAGS.num_core_per_host):
     with tf.device(assign_to_gpu(i, ps_device)), \
         tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
 
       mems_i = [tf.placeholder(tf.float32,