feat: uhm, i changed some things

RobinMeersman 2025-11-25 20:20:08 +01:00
parent b58682cb49
commit 6de4db24cc
27 changed files with 1302 additions and 137 deletions

View file

@@ -176,9 +176,9 @@ class RelMultiHeadAttn(nn.Module):
def _shift(self, x, qlen, klen, mask, left=False):
if qlen > 1:
zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
device=x.device, dtype=x.dtype)
else:
zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)
if left:
mask = mask.flip(1)
@@ -193,7 +193,7 @@ class RelMultiHeadAttn(nn.Module):
def _rel_shift(self, x, zero_triu=False):
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1)
x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
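
Note: these lines implement Transformer-XL's relative-shift trick; the hunk cuts off before the final slice, so here is a minimal, self-contained sketch of the pad-reshape-slice sequence (a reconstruction for illustration, not the repo's exact code):

import torch

def rel_shift(x):
    # x: (qlen, klen, bsz, n_head) attention scores indexed by relative position
    zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), device=x.device, dtype=x.dtype)
    x_padded = torch.cat([zero_pad, x], dim=1)                          # (qlen, klen+1, ...)
    x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])   # same storage viewed as (klen+1, qlen, ...)
    return x_padded[1:].view_as(x)                                      # drop the pad; each row ends up shifted one position relative to the row below it

scores = torch.arange(12.).view(3, 4, 1, 1)   # toy (qlen=3, klen=4) score tensor
print(rel_shift(scores).squeeze())
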
@@ -661,7 +661,7 @@ class MemTransformerLM(nn.Module):
hids = []
if self.attn_type == 0: # default
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
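
Note: pos_seq is just a descending index over the klen attended positions, and clamp_len caps how far back distinct positions are distinguished. A tiny worked example with illustrative values:

import torch
klen, clamp_len = 6, 3
pos_seq = torch.arange(klen - 1, -1, -1.0)   # tensor([5., 4., 3., 2., 1., 0.])
pos_seq.clamp_(max=clamp_len)                # tensor([3., 3., 3., 2., 1., 0.])
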
@@ -691,7 +691,7 @@ class MemTransformerLM(nn.Module):
r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
hids.append(core_out)
elif self.attn_type == 2: # absolute
pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
dtype=word_emb.dtype)
if self.clamp_len > 0:
pos_seq.clamp_(max=self.clamp_len)
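
Note: in the upstream Transformer-XL code pos_seq is then passed to a sinusoidal positional embedding (self.pos_emb, not shown in these hunks). A hedged sketch of such an embedding, with d_model and the names chosen for illustration:

import torch
d_model = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0.0, d_model, 2.0) / d_model))
pos_seq = torch.arange(5, -1, -1.0)                              # descending positions, as above
sinusoid = pos_seq[:, None] * inv_freq[None, :]                  # (klen, d_model/2)
pos_emb = torch.cat([sinusoid.sin(), sinusoid.cos()], dim=-1)    # (klen, d_model)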

View file

@@ -160,7 +160,7 @@ np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
if not args.cuda:
print('WARNING: You have a CUDA device, so you should probably run with --cuda')
else:
torch.cuda.manual_seed_all(args.seed)
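
Note: this snippet assumes an argparse namespace with --cuda and --seed flags. A minimal sketch of that setup (flag defaults here are illustrative):

import argparse
import numpy as np
import torch

parser = argparse.ArgumentParser()
parser.add_argument('--cuda', action='store_true', help='use CUDA')
parser.add_argument('--seed', type=int, default=1111, help='random seed')
args = parser.parse_args()

np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available() and args.cuda:
    torch.cuda.manual_seed_all(args.seed)
device = torch.device('cuda' if args.cuda else 'cpu')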

View file

@@ -50,7 +50,7 @@ class AdaptiveLogSoftmax(nn.Module):
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target,
dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
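
Note: the loop that follows these lines (not shown) walks cutoff_values and fills nll cluster by cluster. A hedged, self-contained sketch of the idea with one head cluster and a single tail cluster (shapes and names are illustrative, not the repo's exact code):

import torch
import torch.nn.functional as F

def adaptive_nll(head_logit, tail_logit, target, cutoff):
    # head_logit: (B, cutoff + 1) logits for frequent tokens plus one slot for the tail cluster
    # tail_logit: (B, V - cutoff) logits for the rare tokens
    head_logprob = F.log_softmax(head_logit, dim=1)
    tail_logprob = F.log_softmax(tail_logit, dim=1)
    nll = torch.zeros_like(target, dtype=head_logprob.dtype, device=head_logprob.device)

    in_head = target < cutoff
    nll[in_head] = -head_logprob[in_head].gather(1, target[in_head].unsqueeze(1)).squeeze(1)

    in_tail = ~in_head
    # P(rare token) = P(tail cluster) * P(token | tail cluster), so the log-probs add
    tail_target = (target[in_tail] - cutoff).unsqueeze(1)
    nll[in_tail] = -(head_logprob[in_tail, -1]
                     + tail_logprob[in_tail].gather(1, tail_target).squeeze(1))
    return nll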

View file

@@ -38,7 +38,7 @@ class LogUniformSampler(object):
with torch.no_grad():
neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
device = labels.device
neg_samples = neg_samples.to(device)
true_log_probs = self.log_q[labels].to(device)
samp_log_probs = self.log_q[neg_samples].to(device)
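
Note: self.dist here is a log-uniform (Zipf-like) proposal over the vocabulary, of the form P(k) = (log(k + 2) - log(k + 1)) / log(V + 1), and self.log_q stores per-class log probabilities derived from that proposal, which the gathers above use to correct the sampled softmax. A hedged sketch of building and sampling such a proposal (V and the sample count are illustrative):

import torch

V = 10000                                                        # illustrative vocabulary size
log_indices = torch.arange(1., V + 2.).log()
dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]    # P(k) for k = 0..V-1; sums to 1
neg_samples = torch.multinomial(dist, 64, replacement=True).unique()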

View file

@@ -112,7 +112,7 @@ class ProjectedAdaptiveLogSoftmax(nn.Module):
head_logprob = F.log_softmax(head_logit, dim=1)
nll = torch.zeros_like(target,
dtype=hidden.dtype, device=hidden.device)
offset = 0
cutoff_values = [0] + self.cutoffs
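
Note: the "projected" variant differs from the plain adaptive softmax above in that each tail cluster first maps the hidden state through a lower-dimensional projection before its cluster softmax. A hedged sketch with illustrative shapes:

import torch
import torch.nn.functional as F

hidden = torch.randn(4, 512)           # (batch, d_model)
proj = torch.randn(512, 128)           # d_model -> d_proj projection for one tail cluster
tail_weight = torch.randn(1000, 128)   # (cluster_vocab, d_proj) output embedding for that cluster
proj_hid = F.linear(hidden, proj.t())              # (batch, d_proj)
tail_logit = F.linear(proj_hid, tail_weight)       # (batch, cluster_vocab)
tail_logprob = F.log_softmax(tail_logit, dim=-1)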

View file

@@ -1,7 +1,7 @@
import os
import tensorflow as tf
def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
def _assign(op):
node_def = op if isinstance(op, tf.NodeDef) else op.node_def
if node_def.op == "Variable":
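
Note: the hunk ends before the body of _assign. In the upstream transformer-xl code the closure is completed roughly as below (a reconstruction, so treat it as a sketch): variables stay on the parameter-server device, every other op goes to the requested GPU, and the returned closure is what train()/evaluate() pass to tf.device().

def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
    def _assign(op):
        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
        if node_def.op == "Variable":
            return ps_dev           # keep variables on the parameter-server device
        else:
            return "/gpu:%d" % gpu  # place all other ops on the chosen GPU
    return _assign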

View file

@@ -724,7 +724,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
hooks = []
with ops.device(device):
user_context = tpu_context.TPUContext(
internal_ctx=ctx,
input_device=device,
@@ -758,7 +758,7 @@ def generate_per_host_enqueue_ops_fn_for_host(
Returns:
list of dict of ops.
"""
with ops.device(device):
num_of_replicas_per_host = ctx.num_of_replicas_per_host
# Convert user input to features and labels. If the user returns a
# dataset, it is initialized and the features and labels extracted via
@@ -799,7 +799,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
captured_infeed_queue = _CapturedObject()
hooks = []
with ops.device(device):
user_context = tpu_context.TPUContext(
internal_ctx=ctx,
input_device=device,
@@ -827,7 +827,7 @@ def generate_per_host_v2_enqueue_ops_fn_for_host(
per_host_sharded_inputs = []
num_replicas_per_host = ctx.num_of_replicas_per_host
cached_signals = None
with ops.device(device):
if not inputs.is_dataset:
raise TypeError('`input_fn` must return a `Dataset` for this mode.')
for _ in range(num_replicas_per_host):
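
Note: the TypeError above means that in this per-host mode the user's input_fn must return a tf.data.Dataset rather than raw tensors. A hedged sketch of an input_fn that satisfies that check (the data and names are illustrative):

import numpy as np
import tensorflow as tf

def input_fn(params):
    batch_size = params['batch_size']   # TPUEstimator injects the per-shard batch size
    features = np.random.rand(1024, 8).astype(np.float32)            # illustrative data
    labels = np.random.randint(0, 2, size=(1024,)).astype(np.int32)
    ds = tf.data.Dataset.from_tensor_slices((features, labels))
    return ds.repeat().shuffle(1024).batch(batch_size, drop_remainder=True)
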
@@ -888,7 +888,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
captured_infeed_queue = _CapturedObject()
hooks = []
device_0 = ctx.tpu_host_placement_function(host_id=0)
with ops.device(device_0):
user_context = tpu_context.TPUContext(
internal_ctx=ctx, input_device=device_0, invocation_index=0)
inputs = _Inputs.from_input_fn(input_fn(user_context))
@@ -924,7 +924,7 @@ def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder,
flattened_inputs = None # Cache result from input_fn.
signals = None
for host_id in xrange(num_hosts):
with ops.device(ctx.tpu_host_placement_function(host_id=host_id)):
for _ in xrange(ctx.num_of_replicas_per_host):
# Note: input_fn is only called once at host 0 for the first replica.
# The features and labels returned from that invocation are
@@ -1147,7 +1147,7 @@ class _InputPipeline(object):
def dequeue_fn():
"""dequeue_fn is used by TPU to retrieve the tensors."""
# In the model-parallel case, both the host-side and device-side
# computations must agree on the core on which infeed takes place. We
# choose to perform infeed on logical core 0 of each replica.
values = self._infeed_queue.generate_dequeue_op(tpu_device=0)
@@ -1173,7 +1173,7 @@ class _InputPipeline(object):
# host.
for host_id in range(num_hosts):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
enqueue_ops_fn, captured_infeed_queue = (
generate_per_core_enqueue_ops_fn_for_host(
@@ -1211,7 +1211,7 @@ class _InputPipeline(object):
else:
for host_id in range(num_hosts):
host_device = tpu_host_placement_fn(host_id=host_id)
with ops.device(host_device):
with ops.name_scope('input_pipeline_task%d' % (host_id)):
if self._ctx.is_input_per_host_with_iterators():
enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = (
@@ -1712,7 +1712,7 @@ class _OutfeedHostCall(object):
for name in self._names:
tensors.extend(self._tensors[name])
with ops.device(tpu.core(0)):
return [tpu_ops.outfeed_enqueue_tuple(tensors)]
def create_tpu_hostcall(self):
@@ -1751,7 +1751,7 @@ class _OutfeedHostCall(object):
# per replica.
for i in xrange(self._ctx.num_replicas):
host_device, ordinal_id = self._ctx.device_for_replica(i)
with ops.device(host_device):
outfeed_tensors = tpu_ops.outfeed_dequeue_tuple(
dtypes=tensor_dtypes,
shapes=tensor_shapes,
@@ -1770,7 +1770,7 @@ class _OutfeedHostCall(object):
# place all ops on tpu host if possible.
#
# TODO(jhseu): Evaluate whether this is right for summaries.
with ops.device(self._ctx.tpu_host_placement_function(replica_id=0)):
for name in self._names:
dequeue_ops = dequeue_ops_by_name[name]
for i, item in enumerate(dequeue_ops):
@@ -2426,7 +2426,7 @@ class TPUEstimator(estimator_lib.Estimator):
# For export_savedmodel, input_fn is never passed to Estimator. So,
# `is_export_mode` must be False.
if ctx.is_running_on_cpu(is_export_mode=False):
with ops.device('/device:CPU:0'):
return input_fn(**kwargs)
# For TPU computation, input_fn should be invoked in a tf.while_loop for
@@ -2971,7 +2971,7 @@ def _wrap_computation_in_while_loop(device, op_fn):
iterations_per_loop_var = _create_or_get_iterations_per_loop()
# By setting parallel_iterations=1, the parallel execution in while_loop is
# basically turned off.
with ops.device(device):
iterations = array_ops.identity(iterations_per_loop_var)
return control_flow_ops.while_loop(
lambda i: i < iterations,
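
Note: this helper wraps one training step in a tf.while_loop so that iterations_per_loop steps run back-to-back on the device, and parallel_iterations=1 forces them to execute strictly in sequence. A hedged sketch of the same pattern using only the public TF1 API (names are illustrative):

import tensorflow as tf

def run_n_times(op_fn, iterations):
    def body(i):
        with tf.control_dependencies([op_fn()]):   # each step must finish before the counter advances
            return i + 1
    return tf.while_loop(lambda i: i < iterations, body, [tf.constant(0)],
                         parallel_iterations=1)
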
@@ -2995,7 +2995,7 @@ def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
# By setting parallel_iterations=1, the parallel execution in while_loop is
# basically turned off.
with ops.device(device):
return control_flow_ops.while_loop(
cond,
computation, [_StopSignals.NON_STOPPING_SIGNAL],
@@ -3006,7 +3006,7 @@ def _validate_tpu_training_graph():
"""Validate graph before running distributed training.
Raises:
ValueError: If the graph seems invalid for running on device
"""
operations = ops.get_default_graph().get_operations()

View file

@@ -249,7 +249,7 @@ def train(n_token, cutoffs, ps_device):
for i in range(FLAGS.num_core_per_host):
reuse = True if i > 0 else None
with tf.device(assign_to_gpu(i, ps_device)), \
tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
mems_i = [tf.placeholder(tf.float32,
@@ -384,7 +384,7 @@ def evaluate(n_token, cutoffs, ps_device):
tower_mems, tower_losses, tower_new_mems = [], [], []
for i in range(FLAGS.num_core_per_host):
with tf.device(assign_to_gpu(i, ps_device)), \
tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
mems_i = [tf.placeholder(tf.float32,
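
Note: both train() and evaluate() build one model replica per GPU with shared weights; the first tower creates the variables and later towers reuse them through the variable scope. A hedged sketch of that tower loop with the surrounding pieces stubbed out (build_tower, inputs and the constants are illustrative; assign_to_gpu is the placement helper from the hunk above):

import tensorflow as tf

num_core_per_host, ps_device = 2, "/device:CPU:0"   # illustrative values

def build_tower(x):                                  # illustrative stand-in for the real model
    w = tf.get_variable("w", shape=[8, 1])
    return tf.reduce_mean(tf.matmul(x, w))

inputs = [tf.random_uniform([4, 8]) for _ in range(num_core_per_host)]

tower_losses = []
for i in range(num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
         tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
        tower_losses.append(build_tower(inputs[i]))
loss = tf.add_n(tower_losses) / len(tower_losses)    # average of the per-GPU losses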