feat: add dataset and trainer abstractions, rework training entry point

RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@@ -0,0 +1,5 @@
+import lorem
+
+class LoremIpsumDataset:
+    def __init__(self):
+        self.data = lorem.text(paragraphs=100)
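
The dataset class above stores only raw text. For byte-level modelling that text still has to become integer tokens; below is a minimal sketch of that step, assuming lorem.text() returns a plain string (the encoding itself is not part of this commit):

import torch

ds = LoremIpsumDataset()
raw = ds.data.encode("utf-8")  # str -> bytes
tokens = torch.frombuffer(bytearray(raw), dtype=torch.uint8).long()  # one id (0-255) per byte
print(tokens.shape)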

View file

@@ -0,0 +1,2 @@
+from .EnWik9 import EnWik9DataSet
+from .LoremIpsumDataset import LoremIpsumDataset

View file

@@ -1,93 +1,55 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import optuna.trial as tr
-from torch.utils.data import DataLoader
-from tqdm import tqdm
-import argparse
-from optuna_trial import create_model
-from utils import make_context_pairs, load_data
-import optuna
+from argparse import ArgumentParser
+from math import ceil
+
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+
+from datasets import EnWik9DataSet, LoremIpsumDataset
+from trainers import FullTrainer, OptunaTrainer, Trainer
+
+BATCH_SIZE = 64
+DEVICE = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
 
 # hyper parameters
 context_length = 128
 
-def train_and_eval(
-    model: nn.Module,
-    training_data: bytes,
-    validation_data: bytes,
-    batch_size: int,
-    epochs: int = 100,
-    learning_rate: float = 1e-3,
-    device: torch.device = torch.device("cpu")
-) -> dict:
-    model.to(device)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
-
-    training_loader = DataLoader(make_context_pairs(training_data, context_length=context_length), batch_size=batch_size)
-    validation_loader = DataLoader(make_context_pairs(validation_data, context_length=context_length), batch_size=batch_size)
-
-    training_losses = []
-    validation_losses = []
-    best_val_loss = float("inf")
-    for epoch in range(epochs):
-        model.train()
-        train_loss = 0
-        for x, y in tqdm(training_loader, desc=f"Epoch {epoch}"):
-            x, y = x.to(device), y.to(device)
-            prediction = model(x)
-            loss = F.cross_entropy(prediction, y)
-            train_loss += loss.item()
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-        training_losses.append(train_loss / len(training_loader))
-
-        model.eval()
-        with torch.no_grad():
-            val_loss = 0
-            for x, y in validation_loader:
-                x, y = x.to(device), y.to(device)
-                prediction = model(x)
-                loss = F.cross_entropy(prediction, y)
-                val_loss += loss.item()
-            validation_losses.append(val_loss / len(validation_loader))
-            if validation_losses[-1] < best_val_loss:
-                best_val_loss = validation_losses[-1]
-
-    return {
-        "training_losses": training_losses,
-        "validation_losses": validation_losses,
-        "best_validation_loss": best_val_loss
-    }
-
-def objective_function(trial: tr.Trial, train_data: bytes, validation_data: bytes, batch_size: int):
-    model = create_model(trial)
-    result = train_and_eval(model, train_data, validation_data, batch_size)
-    return result["best_validation_loss"]
-
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--train-data", type=str, required=True)
-    parser.add_argument("--validation-data", type=str, required=True)
-    parser.add_argument("--batch-size", type=int, default=128)
-    args = parser.parse_args()
-    print(args)
-
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    train_data = load_data(args.train_data)
-    validation_data = load_data(args.validation_data)
-    batch_size = args.batch_size
-
-    print(f"training data length: {len(train_data)}")
-    print(f"validation data length: {len(validation_data)}")
-    print(f"batch size: {batch_size}")
-
-    study = optuna.create_study(study_name="CNN network", direction="minimize")
-    study.optimize(lambda trial: objective_function(trial, train_data, validation_data, batch_size), n_trials=10)
+    print(f"Running on device: {DEVICE}...")
+    parser = ArgumentParser()
+    parser.add_argument("--method", choices=["optuna", "train"], required=True)
+    parser.add_argument("--model-path", type=str, required=False)
+    args = parser.parse_args()
+
+    if args.method == "train":
+        dataset = EnWik9DataSet()
+    elif args.method == "optuna":
+        dataset = LoremIpsumDataset()
+    else:
+        raise ValueError(f"Unknown method: {args.method}")
+
+    dataset_length = len(dataset)
+    training_size = ceil(0.8 * dataset_length)
+    print(f"training set size = {training_size}, validation set size {dataset_length - training_size}")
+
+    data = dataset.data["text"]
+    train_set, validate_set = torch.utils.data.random_split(TensorDataset(data),
+                                                            [training_size, dataset_length - training_size])
+    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
+    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
+
+    loss_fn = torch.nn.CrossEntropyLoss()
+
+    model = None
+    if args.model_path is not None:
+        model = torch.load(args.model_path)
+
+    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()
+    trainer.execute(
+        model=model,
+        train_loader=training_loader,
+        validation_loader=validation_loader,
+        loss_fn=loss_fn,
+        n_epochs=200,
+        device=DEVICE
+    )
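
One subtlety in the new entry point: TensorDataset yields 1-tuples, so every batch coming out of these loaders is a tuple of tensors, not a bare tensor. A minimal sketch with a stand-in tensor (not the real dataset) showing what the trainer receives:

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

data = torch.arange(100)  # stand-in for the encoded text
train_set, val_set = random_split(TensorDataset(data), [80, 20])
loader = DataLoader(train_set, batch_size=16, shuffle=True)
(batch,) = next(iter(loader))  # each batch is a 1-tuple of tensors
print(batch.shape)  # torch.Size([16])

The new train() loop further down iterates `for data in loader` and passes `data` to the model unmodified, so unpacking that tuple is left to the model (or a future fix).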

View file

@@ -0,0 +1 @@
+from .cnn import CNNPredictor

View file

@@ -8,7 +8,7 @@ class CausalConv1d(nn.Conv1d):
     def __init__(self, input_channels, output_channels, kernel_size, **kwargs):
         super().__init__(input_channels, output_channels, kernel_size, padding=kernel_size-1, **kwargs)
 
     def forward(self, input: Tensor) -> Tensor:
-        return super().forward(input)
+        return super().forward(input)[:, :, :input.size(-1)]
 
 class CNNPredictor(nn.Module):
     def __init__(
@@ -41,5 +41,5 @@ class CNNPredictor(nn.Module):
         emdedding = emdedding.transpose(1, 2)  # B, H, L
         prediction = self.network(emdedding)
         last_prediction = prediction[:, :, -1]
-        return softmax(self.output_layer(last_prediction), dim=-1)  # convert output of linear layer to prob. distr.
+        return self.output_layer(last_prediction)
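
The forward() change above is what makes the convolution causal: nn.Conv1d with padding=kernel_size-1 pads both ends, so trimming the output back to the input length keeps exactly the positions whose receptive field ends at or before t. A small check, assuming CausalConv1d is importable as defined above:

import torch

conv = CausalConv1d(input_channels=1, output_channels=1, kernel_size=3)
x = torch.randn(1, 1, 10)
y = conv(x)
x2 = x.clone()
x2[:, :, 5:] = 0.0  # perturb only the "future" of position 4
print(torch.allclose(y[:, :, :5], conv(x2)[:, :, :5]))  # True: nothing leaks backwards

The second hunk drops the softmax because main.py now uses torch.nn.CrossEntropyLoss, which expects raw logits and applies log-softmax internally.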

View file

@@ -1,18 +0,0 @@
-import optuna.trial as tr
-
-from cnn import CNNPredictor
-
-def create_model(trial: tr.Trial, vocab_size: int = 256, context_length: int = 128):
-    num_layers = trial.suggest_int("num_layers", 1, 6)
-    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
-    kernel_size = trial.suggest_int("kernel_size", 2, 7)
-    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
-    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
-    return CNNPredictor(
-        vocab_size=vocab_size,
-        num_layers=num_layers,
-        hidden_dim=hidden_dim,
-        kernel_size=kernel_size,
-        dropout_prob=dropout_prob,
-        use_batchnorm=use_batchnorm
-    )

View file

@@ -0,0 +1,26 @@
+from typing import Callable
+
+import torch
+from torch import nn as nn
+from torch.utils.data import DataLoader
+
+from .trainer import Trainer
+from .train import train
+from ..utils import print_losses
+
+
+class FullTrainer(Trainer):
+    def execute(
+        self,
+        model: nn.Module | None,
+        train_loader: DataLoader,
+        validation_loader: DataLoader,
+        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        n_epochs: int,
+        device: str
+    ) -> None:
+        if model is None:
+            raise ValueError("Model must be provided: run optuna optimizations first")
+
+        model.to(device)
+        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
+        print_losses(train_loss, val_loss)

View file

@@ -0,0 +1,63 @@
+from typing import Callable
+
+import optuna
+import optuna.trial as tr
+import torch
+from torch import nn as nn
+from torch.utils.data import DataLoader
+
+from .trainer import Trainer
+from ..model.cnn import CNNPredictor
+from .train import train
+
+
+def create_model(trial: tr.Trial, vocab_size: int = 256):
+    num_layers = trial.suggest_int("num_layers", 1, 6)
+    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
+    kernel_size = trial.suggest_int("kernel_size", 2, 7)
+    dropout_prob = trial.suggest_float("dropout_prob", 0.1, 0.5)
+    use_batchnorm = trial.suggest_categorical("use_batchnorm", [True, False])
+    return CNNPredictor(
+        vocab_size=vocab_size,
+        num_layers=num_layers,
+        hidden_dim=hidden_dim,
+        kernel_size=kernel_size,
+        dropout_prob=dropout_prob,
+        use_batchnorm=use_batchnorm
+    )
+
+
+def objective_function(
+    trial: tr.Trial,
+    training_loader: DataLoader,
+    validation_loader: DataLoader,
+    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    device: str
+):
+    model = create_model(trial).to(device)
+    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
+    return min(validation_loss)
+
+
+class OptunaTrainer(Trainer):
+    def execute(
+        self,
+        model: nn.Module | None,
+        train_loader: DataLoader,
+        validation_loader: DataLoader,
+        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        n_epochs: int,
+        device: str
+    ) -> None:
+        study = optuna.create_study(study_name="CNN network", direction="minimize")
+        study.optimize(
+            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
+            n_trials=20
+        )
+
+        best_params = study.best_trial.params
+        best_model = CNNPredictor(
+            **best_params
+        )
+        torch.save(best_model, "models/final_model.pt")
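
Note that the model written to models/final_model.pt carries the best trial's architecture but freshly initialized weights, so it still needs the full training run (and torch.save will fail if the models/ directory does not exist yet). A hedged sketch of loading it back for that run; weights_only=False is required on recent PyTorch because a whole nn.Module was pickled, so only do this with checkpoints you trust:

import torch

model = torch.load("models/final_model.pt", weights_only=False)
print(model)  # a CNNPredictor configured with study.best_trial.params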

View file

@@ -0,0 +1,3 @@
+from .FullTrainer import FullTrainer
+from .OptunaTrainer import OptunaTrainer
+from .trainer import Trainer

View file

@@ -0,0 +1,50 @@
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.data.dataloader import DataLoader
+from tqdm import tqdm
+
+
+def train(
+    model: nn.Module,
+    training_loader: DataLoader,
+    validation_loader: DataLoader,
+    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+    epochs: int = 100,
+    learning_rate: float = 1e-3,
+    weight_decay: float = 1e-8
+) -> tuple[list[float], list[float]]:
+    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
+
+    avg_training_losses = []
+    avg_validation_losses = []
+    for epoch in range(epochs):
+        model.train()
+        total_loss = []
+        for data in tqdm(training_loader):
+            optimizer.zero_grad()
+            x_hat = model(data)
+            loss = loss_fn(x_hat, data)
+            loss.backward()
+            optimizer.step()
+            total_loss.append(loss.item())
+        avg_training_losses.append(sum(total_loss) / len(total_loss))
+
+        model.eval()  # disable dropout/batchnorm updates during validation
+        with torch.no_grad():
+            losses = []
+            for data in validation_loader:
+                x_hat = model(data)
+                loss = loss_fn(x_hat, data)
+                losses.append(loss.item())
+            avg_loss = sum(losses) / len(losses)
+            avg_validation_losses.append(avg_loss)
+            tqdm.write(f"epoch: {epoch + 1}, avg loss = {avg_loss}")
+
+    return avg_training_losses, avg_validation_losses
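
A hypothetical smoke test for train() above, using a toy identity task so the call signature and return values are visible. It assumes each batch is a single tensor serving as both input and target, and that model and data share a device; a plain list of tensors satisfies this, since the default collate stacks them:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

samples = [torch.randn(8) for _ in range(32)]  # 32 samples of dimension 8
loader = DataLoader(samples, batch_size=4)
model = nn.Linear(8, 8)
train_losses, val_losses = train(model, loader, loader, nn.MSELoss(), epochs=2)
print(train_losses[-1], val_losses[-1])  # per-epoch average losses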

View file

@@ -0,0 +1,22 @@
+from abc import ABC, abstractmethod
+from typing import Callable
+
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+
+
+class Trainer(ABC):
+    """Abstract class for trainers."""
+
+    @abstractmethod
+    def execute(
+        self,
+        model: nn.Module | None,
+        train_loader: DataLoader,
+        validation_loader: DataLoader,
+        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
+        n_epochs: int,
+        device: str
+    ) -> None:
+        pass

View file

@@ -0,0 +1 @@
+from .utils import *

View file

@@ -14,6 +14,17 @@ def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
     plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
     plt.show()
 
+def print_losses(train_losses: list[float], validation_losses: list[float], show=False):
+    plt.plot(train_losses, label="Training loss")
+    plt.plot(validation_losses, label="Validation loss")
+    plt.xlabel("Epoch")
+    plt.ylabel("Loss (cross entropy)")
+    plt.legend()
+    plt.savefig("losses.png")  # save before show(): show() clears the current figure
+    if show:
+        plt.show()
+
 def load_data(path: str) -> bytes:
     with open(path, "rb") as f:

View file

@@ -36,7 +36,7 @@ def CDF_fn(pz, bin_width, variable_type, distribution_type):
     bin_locations = torch.arange(-n_bins // 2, n_bins // 2)[None, None, None, None, :] + MEAN.cpu()[..., None]
     bin_locations = bin_locations.float() * bin_width
-    bin_locations = bin_locations.to(device=pz[0].device)
+    bin_locations = bin_locations.to(device=pz[0].DEVICE)
 
     pz = [param[:, :, :, :, None] for param in pz]
 
     cdf = cdf_fn(
@@ -86,7 +86,7 @@ def decode_sample(
         state, pz, variable_type, distribution_type, bin_width=1./256):
     state = rans.unflatten(state)
-    device = pz[0].device
+    device = pz[0].DEVICE
     size = pz[0].size()[0:4]
 
     CDFs, MEAN = CDF_fn(pz, bin_width, variable_type, distribution_type)

View file

@@ -190,7 +190,7 @@ def run(args, kwargs):
     import models.Model as Model
     model = Model.Model(args)
-    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    args.DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
     model.set_temperature(args.temperature)
     model.enable_hard_round(args.hard_round)
@@ -208,7 +208,7 @@ def run(args, kwargs):
         print("Let's use", torch.cuda.device_count(), "GPUs!")
         model = torch.nn.DataParallel(model, dim=0)
-    model.to(args.device)
+    model.to(args.DEVICE)
 
     def lr_lambda(epoch):
         return min(1., (epoch+1) / args.warmup) * np.power(args.lr_decay, epoch)

View file

@@ -34,7 +34,7 @@ def _stacked_sigmoid(x, temperature, n_approx=3):
     x_remainder = x_remainder.view(size + (1,))
 
     translation = torch.arange(n_approx) - n_approx // 2
-    translation = translation.to(device=x.device, dtype=x.dtype)
+    translation = translation.to(device=x.DEVICE, dtype=x.dtype)
     translation = translation.view([1] * len(size) + [len(translation)])
 
     out = torch.sigmoid((x_remainder - translation) / temperature).sum(dim=-1)

View file

@@ -17,7 +17,7 @@ def train(epoch, train_loader, model, opt, args):
     for batch_idx, (data, _) in enumerate(train_loader):
         data = data.view(-1, *args.input_size)
-        data = data.to(args.device)
+        data = data.to(args.DEVICE)
 
         opt.zero_grad()
         loss, bpd, bpd_per_prior, pz, z, pys, py, ldj = model(data)

View file

@@ -5,8 +5,12 @@ description = "Add your description here"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
+    "datasets>=4.4.1",
+    "lorem>=0.1.1",
     "matplotlib>=3.10.7",
     "numpy>=2.3.4",
     "optuna>=4.5.0",
     "torch>=2.9.0",
+    "torchdata>=0.11.0",
+    "torchvision>=0.24.0",
 ]

View file

@@ -176,9 +176,9 @@ class RelMultiHeadAttn(nn.Module):
     def _shift(self, x, qlen, klen, mask, left=False):
         if qlen > 1:
             zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
-                                   device=x.device, dtype=x.dtype)
+                                   device=x.DEVICE, dtype=x.dtype)
         else:
-            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
+            zero_pad = torch.zeros(0, device=x.DEVICE, dtype=x.dtype)
 
         if left:
             mask = mask.flip(1)
@@ -193,7 +193,7 @@ class RelMultiHeadAttn(nn.Module):
     def _rel_shift(self, x, zero_triu=False):
         zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
-                               device=x.device, dtype=x.dtype)
+                               device=x.DEVICE, dtype=x.dtype)
         x_padded = torch.cat([zero_pad, x], dim=1)
 
         x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
@@ -661,7 +661,7 @@ class MemTransformerLM(nn.Module):
         hids = []
         if self.attn_type == 0:  # default
-            pos_seq = torch.arange(klen-1, -1, -1.0, device=word_emb.device,
+            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.DEVICE,
                                    dtype=word_emb.dtype)
             if self.clamp_len > 0:
                 pos_seq.clamp_(max=self.clamp_len)
@@ -691,7 +691,7 @@ class MemTransformerLM(nn.Module):
                                  r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                 hids.append(core_out)
         elif self.attn_type == 2:  # absolute
-            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
+            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.DEVICE,
                                    dtype=word_emb.dtype)
             if self.clamp_len > 0:
                 pos_seq.clamp_(max=self.clamp_len)

View file

@@ -160,7 +160,7 @@ np.random.seed(args.seed)
 torch.manual_seed(args.seed)
 if torch.cuda.is_available():
     if not args.cuda:
-        print('WARNING: You have a CUDA device, so you should probably run with --cuda')
+        print('WARNING: You have a CUDA DEVICE, so you should probably run with --cuda')
     else:
         torch.cuda.manual_seed_all(args.seed)

View file

@@ -50,7 +50,7 @@ class AdaptiveLogSoftmax(nn.Module):
             head_logprob = F.log_softmax(head_logit, dim=1)
 
             nll = torch.zeros_like(target,
-                                   dtype=hidden.dtype, device=hidden.device)
+                                   dtype=hidden.dtype, device=hidden.DEVICE)
 
             offset = 0
             cutoff_values = [0] + self.cutoffs

View file

@@ -38,7 +38,7 @@ class LogUniformSampler(object):
         with torch.no_grad():
             neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
-            device = labels.device
+            device = labels.DEVICE
             neg_samples = neg_samples.to(device)
             true_log_probs = self.log_q[labels].to(device)
             samp_log_probs = self.log_q[neg_samples].to(device)

View file

@@ -112,7 +112,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
             head_logprob = F.log_softmax(head_logit, dim=1)
 
             nll = torch.zeros_like(target,
-                                   dtype=hidden.dtype, device=hidden.device)
+                                   dtype=hidden.dtype, device=hidden.DEVICE)
 
             offset = 0
             cutoff_values = [0] + self.cutoffs

View file

@@ -1,7 +1,7 @@
 import os
 import tensorflow as tf
 
-def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
+def assign_to_gpu(gpu=0, ps_dev="/DEVICE:CPU:0"):
     def _assign(op):
         node_def = op if isinstance(op, tf.NodeDef) else op.node_def
         if node_def.op == "Variable":

View file

@@ -724,7 +724,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
   hooks = []
 
-  with ops.device(device):
+  with ops.DEVICE(device):
     user_context = tpu_context.TPUContext(
         internal_ctx=ctx,
         input_device=device,
@@ -758,7 +758,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
     Returns:
       list of dict of ops.
     """
-    with ops.device(device):
+    with ops.DEVICE(device):
      num_of_replicas_per_host = ctx.num_of_replicas_per_host
      # Convert user input to features and labels. If the user returns a
      # dataset, it is initialized and the features and labels extracted via
@@ -799,7 +799,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
   captured_infeed_queue = _CapturedObject()
   hooks = []
 
-  with ops.device(device):
+  with ops.DEVICE(device):
     user_context = tpu_context.TPUContext(
         internal_ctx=ctx,
         input_device=device,
@@ -827,7 +827,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
     per_host_sharded_inputs = []
     num_replicas_per_host = ctx.num_of_replicas_per_host
     cached_signals = None
-    with ops.device(device):
+    with ops.DEVICE(device):
      if not inputs.is_dataset:
        raise TypeError('`input_fn` must return a `Dataset` for this mode.')
      for _ in range(num_replicas_per_host):
@@ -888,7 +888,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
  captured_infeed_queue = _CapturedObject()
  hooks = []
  device_0 = ctx.tpu_host_placement_function(host_id=0)
-  with ops.device(device_0):
+  with ops.DEVICE(device_0):
    user_context = tpu_context.TPUContext(
        internal_ctx=ctx, input_device=device_0, invocation_index=0)
    inputs = _Inputs.from_input_fn(input_fn(user_context))
@@ -924,7 +924,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
    flattened_inputs = None  # Cache result from input_fn.
    signals = None
    for host_id in xrange(num_hosts):
-      with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
+      with ops.DEVICE(ctx.tpu_host_placement_function(host_id=host_id)):
        for _ in xrange(ctx.num_of_replicas_per_host):
          # Note: input_fn is only called once at host 0 for the first replica.
          # The features and labels returned from that invocation are
@@ -1147,7 +1147,7 @@ class _InputPipeline(object):
    def dequeue_fn():
      """dequeue_fn is used by TPU to retrieve the tensors."""
-      # In the model-parallel case, both the host-side and device-side
+      # In the model-parallel case, both the host-side and DEVICE-side
      # computations must agree on the core on which infeed takes place. We
      # choose to perform infeed on logical core 0 of each replica.
      values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
@@ -1173,7 +1173,7 @@ class _InputPipeline(object):
      # host.
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
-        with ops.device(host_device):
+        with ops.DEVICE(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            enqueue_ops_fn, captured_infeed_queue = (
                generate_per_core_enqueue_ops_fn_for_host(
@@ -1211,7 +1211,7 @@ class _InputPipeline(object):
    else:
      for host_id in range(num_hosts):
        host_device = tpu_host_placement_fn(host_id=host_id)
-        with ops.device(host_device):
+        with ops.DEVICE(host_device):
          with ops.name_scope('input_pipeline_task%d' % (host_id)):
            if self._ctx.is_input_per_host_with_iterators():
              enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
@@ -1712,7 +1712,7 @@ class _OutfeedHostCall(object):
    for name in self._names:
      tensors.extend(self._tensors[name])
 
-    with ops.device(tpu.core(0)):
+    with ops.DEVICE(tpu.core(0)):
      return [tpu_ops.outfeed_enqueue_tuple(tensors)]
 
  def create_tpu_hostcall(self):
@@ -1751,7 +1751,7 @@ class _OutfeedHostCall(object):
    # per replica.
    for i in xrange(self._ctx.num_replicas):
      host_device, ordinal_id = self._ctx.device_for_replica(i)
-      with ops.device(host_device):
+      with ops.DEVICE(host_device):
        outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
            dtypes=tensor_dtypes,
            shapes=tensor_shapes,
@@ -1770,7 +1770,7 @@ class _OutfeedHostCall(object):
    # place all ops on tpu host if possible.
    #
    # TODO(jhseu): Evaluate whether this is right for summaries.
-    with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
+    with ops.DEVICE(self._ctx.tpu_host_placement_function(replica_id=0)):
      for name in self._names:
        dequeue_ops = dequeue_ops_by_name[name]
        for i, item in enumerate(dequeue_ops):
@@ -2426,7 +2426,7 @@ class TPUEstimator(estimator_lib.Estimator):
      # For export_savedmodel, input_fn is never passed to Estimator. So,
      # `is_export_mode` must be False.
      if ctx.is_running_on_cpu(is_export_mode=False):
-        with ops.device('/device:CPU:0'):
+        with ops.DEVICE('/DEVICE:CPU:0'):
          return input_fn(**kwargs)
 
      # For TPU computation, input_fn should be invoked in a tf.while_loop for
@@ -2971,7 +2971,7 @@ def _wrap_computation_in_while_loop(device, op_fn):
  iterations_per_loop_var = _create_or_get_iterations_per_loop()
  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
-  with ops.device(device):
+  with ops.DEVICE(device):
    iterations = array_ops.identity(iterations_per_loop_var)
    return control_flow_ops.while_loop(
        lambda i: i < iterations,
@@ -2995,7 +2995,7 @@ def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  # By setting parallel_iterations=1, the parallel execution in while_loop is
  # basically turned off.
-  with ops.device(device):
+  with ops.DEVICE(device):
    return control_flow_ops.while_loop(
        cond,
        computation, [_StopSignals.NON_STOPPING_SIGNAL],
@@ -3006,7 +3006,7 @@ def _validate_tpu_training_graph():
  """Validate graph before running distributed training.
 
  Raises:
-    ValueError: If the graph seems invalid for running on device
+    ValueError: If the graph seems invalid for running on DEVICE
  """
  operations = ops.get_default_graph().get_operations()

View file

@@ -249,7 +249,7 @@ def train(n_token, cutoffs, ps_device):
   for i in range(FLAGS.num_core_per_host):
     reuse = True if i > 0 else None
-    with tf.device(assign_to_gpu(i, ps_device)), \
+    with tf.DEVICE(assign_to_gpu(i, ps_device)), \
         tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
 
       mems_i = [tf.placeholder(tf.float32,
@@ -384,7 +384,7 @@ def evaluate(n_token, cutoffs, ps_device):
   tower_mems, tower_losses, tower_new_mems = [], [], []
 
   for i in range(FLAGS.num_core_per_host):
-    with tf.device(assign_to_gpu(i, ps_device)), \
+    with tf.DEVICE(assign_to_gpu(i, ps_device)), \
         tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
 
       mems_i = [tf.placeholder(tf.float32,

uv.lock (generated): 1034 lines changed

File diff suppressed because it is too large