chore(transformer-xl): Initial commit
This commit is contained in:
parent
ef4684ef39
commit
10512876f2
46 changed files with 10547 additions and 0 deletions
462
transformer-xl/tf/train.py
Normal file
462
transformer-xl/tf/train.py
Normal file
|
|
@ -0,0 +1,462 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
from absl import flags
|
||||
import absl.logging as _logging # pylint: disable=unused-import
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow.gfile import Exists as exists
|
||||
import model
|
||||
import data_utils
|
||||
import tpu_estimator
|
||||
|
||||
import numpy as np
|
||||
from time import sleep
|
||||
|
||||
|
||||
# TPU parameters
|
||||
flags.DEFINE_string("master", default=None,
|
||||
help="master")
|
||||
flags.DEFINE_string("tpu", default=None,
|
||||
help="The Cloud TPU to use for training. This should be either the name "
|
||||
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
|
||||
flags.DEFINE_string("gcp_project", default=None,
|
||||
help="Project name for the Cloud TPU-enabled project. If not specified, "
|
||||
"we will attempt to automatically detect the GCE project from metadata.")
|
||||
flags.DEFINE_string("tpu_zone",default=None,
|
||||
help="GCE zone where the Cloud TPU is located in. If not specified, we "
|
||||
"will attempt to automatically detect the GCE project from metadata.")
|
||||
flags.DEFINE_bool("use_tpu", default=True,
|
||||
help="Use TPUs rather than plain CPUs.")
|
||||
flags.DEFINE_integer("num_hosts", default=1,
|
||||
help="number of TPU hosts")
|
||||
flags.DEFINE_integer("num_core_per_host", default=8,
|
||||
help="number of cores per host")
|
||||
|
||||
# Experiment (data/checkpoint/directory) parameters
|
||||
flags.DEFINE_string("data_dir", default="",
|
||||
help="Path to tf-records directory.")
|
||||
flags.DEFINE_string("record_info_dir", default="",
|
||||
help="Path to local directory containing filenames.txt.")
|
||||
flags.DEFINE_string("corpus_info_path", default="",
|
||||
help="Path to corpus-info.json file.")
|
||||
flags.DEFINE_string("model_dir", default=None,
|
||||
help="Estimator model_dir.")
|
||||
flags.DEFINE_bool("do_eval", default=False,
|
||||
help="Whether to run eval on the dev set.")
|
||||
flags.DEFINE_bool("track_mean", default=True,
|
||||
help="Trace mean loss during training.")
|
||||
flags.DEFINE_string("eval_ckpt_path", None,
|
||||
help="Checkpoint path for evaluation."
|
||||
"If set, model_dir will be ignored."
|
||||
"If unset, will use the latest ckpt in model_dir.")
|
||||
flags.DEFINE_string("warm_start_path", None,
|
||||
help="Checkpoint path for warm start."
|
||||
"If set, will clear Adam states."
|
||||
"Note that the new model_dir should be different"
|
||||
" from warm_start_path.")
|
||||
|
||||
# Optimization paramenters
|
||||
flags.DEFINE_float("learning_rate", default=2.5e-4,
|
||||
help="Maximum learning rate.")
|
||||
flags.DEFINE_float("clip", default=0.25,
|
||||
help="Gradient clipping value.")
|
||||
# for cosine decay
|
||||
flags.DEFINE_float("min_lr_ratio", default=0.01,
|
||||
help="Minimum ratio learning rate.")
|
||||
flags.DEFINE_integer("warmup_steps", default=0,
|
||||
help="Number of steps for linear lr warmup.")
|
||||
|
||||
# Training parameters
|
||||
flags.DEFINE_integer("train_batch_size", default=60,
|
||||
help="Size of train batch.")
|
||||
flags.DEFINE_integer("eval_batch_size", default=60,
|
||||
help="Size of valid batch.")
|
||||
flags.DEFINE_integer("train_steps", default=100000,
|
||||
help="Total number of training steps.")
|
||||
flags.DEFINE_integer("iterations", default=500,
|
||||
help="Number of iterations per repeat loop.")
|
||||
flags.DEFINE_integer("save_steps", default=10000,
|
||||
help="number of steps for model checkpointing.")
|
||||
|
||||
# Evaluation parameters
|
||||
flags.DEFINE_integer("max_eval_batch", default=-1,
|
||||
help="Set -1 to turn off. Only used in test mode.")
|
||||
flags.DEFINE_bool("do_eval_only", default=False,
|
||||
help="Run evaluation only.")
|
||||
flags.DEFINE_integer("start_eval_steps", default=10000,
|
||||
help="Which checkpoint to start with in `do_eval_only` mode.")
|
||||
flags.DEFINE_string("eval_split", "valid",
|
||||
help="Which data split to evaluate.")
|
||||
|
||||
# Model paramenters
|
||||
flags.DEFINE_integer("tgt_len", default=70,
|
||||
help="Number of steps to predict")
|
||||
flags.DEFINE_integer("mem_len", default=70,
|
||||
help="Number of steps to cache")
|
||||
flags.DEFINE_bool("same_length", default=False,
|
||||
help="Same length attention")
|
||||
flags.DEFINE_integer("clamp_len", default=-1,
|
||||
help="Clamp length")
|
||||
|
||||
flags.DEFINE_integer("n_layer", default=6,
|
||||
help="Number of layers.")
|
||||
flags.DEFINE_integer("d_model", default=500,
|
||||
help="Dimension of the model.")
|
||||
flags.DEFINE_integer("d_embed", default=500,
|
||||
help="Dimension of the embeddings.")
|
||||
flags.DEFINE_integer("n_head", default=10,
|
||||
help="Number of attention heads.")
|
||||
flags.DEFINE_integer("d_head", default=50,
|
||||
help="Dimension of each attention head.")
|
||||
flags.DEFINE_integer("d_inner", default=1000,
|
||||
help="Dimension of inner hidden size in positionwise feed-forward.")
|
||||
flags.DEFINE_float("dropout", default=0.1,
|
||||
help="Dropout rate.")
|
||||
flags.DEFINE_float("dropatt", default=0.1,
|
||||
help="Attention dropout rate.")
|
||||
flags.DEFINE_bool("untie_r", default=False,
|
||||
help="untie r_w_bias and r_r_bias")
|
||||
|
||||
# Adaptive Softmax / Embedding
|
||||
flags.DEFINE_bool("tie_weight", default=True,
|
||||
help="Tie embedding and softmax weight.")
|
||||
flags.DEFINE_integer("div_val", default=1,
|
||||
help="Divide the embedding size by this val for each bin")
|
||||
flags.DEFINE_bool("proj_share_all_but_first", default=False,
|
||||
help="True to share all but first projs, False not to share.")
|
||||
flags.DEFINE_bool("proj_same_dim", default=True,
|
||||
help="Project the bin with the same dimension.")
|
||||
|
||||
# Parameter initialization
|
||||
flags.DEFINE_enum("init", default="normal",
|
||||
enum_values=["normal", "uniform"],
|
||||
help="Initialization method.")
|
||||
flags.DEFINE_float("init_std", default=0.02,
|
||||
help="Initialization std when init is normal.")
|
||||
flags.DEFINE_float("proj_init_std", default=0.01,
|
||||
help="Initialization std for embedding projection.")
|
||||
flags.DEFINE_float("init_range", default=0.1,
|
||||
help="Initialization std when init is uniform.")
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
def metric_fn(loss):
|
||||
"""Evaluation metric Fn which runs on CPU."""
|
||||
perplexity = tf.exp(tf.reduce_mean(loss))
|
||||
bpc = tf.reduce_mean(loss) / tf.constant(math.log(2))
|
||||
return {
|
||||
"perplexity": tf.metrics.mean(perplexity),
|
||||
"bpc": tf.metrics.mean(bpc),
|
||||
}
|
||||
|
||||
|
||||
def get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes):
|
||||
def model_fn(features, labels, mode, params):
|
||||
is_training = (mode == tf.estimator.ModeKeys.TRAIN)
|
||||
|
||||
|
||||
batch_size = params["batch_size"]
|
||||
|
||||
mems = params["cache"]
|
||||
inp = tf.transpose(features["inputs"], [1, 0])
|
||||
tgt = tf.transpose(features["labels"], [1, 0])
|
||||
|
||||
bin_sizes = train_bin_sizes if is_training else eval_bin_sizes
|
||||
if bin_sizes:
|
||||
inp_perms = [tf.transpose(features["inp_mask"], [1, 0])]
|
||||
tgt_perms = [tf.transpose(features["tgt_mask"], [1, 0])]
|
||||
|
||||
head_tgt = tf.transpose(features["head_labels"], [1, 0])
|
||||
|
||||
for b in range(len(bin_sizes)):
|
||||
inp_perm = tf.transpose(features["inp_perm_{}".format(b)], [1, 0, 2])
|
||||
tgt_perm = tf.transpose(features["tgt_perm_{}".format(b)], [1, 0, 2])
|
||||
|
||||
inp_perms.append(inp_perm)
|
||||
tgt_perms.append(tgt_perm)
|
||||
else:
|
||||
inp_perms, tgt_perms, head_tgt = None, None, None
|
||||
|
||||
if FLAGS.init == "uniform":
|
||||
initializer = tf.initializers.random_uniform(
|
||||
minval=-FLAGS.init_range,
|
||||
maxval=FLAGS.init_range,
|
||||
seed=None)
|
||||
elif FLAGS.init == "normal":
|
||||
initializer = tf.initializers.random_normal(
|
||||
stddev=FLAGS.init_std,
|
||||
seed=None)
|
||||
proj_initializer = tf.initializers.random_normal(
|
||||
stddev=FLAGS.proj_init_std,
|
||||
seed=None)
|
||||
|
||||
tie_projs = [False for _ in range(len(cutoffs) + 1)]
|
||||
if FLAGS.proj_share_all_but_first:
|
||||
for i in range(1, len(tie_projs)):
|
||||
tie_projs[i] = True
|
||||
|
||||
tf.logging.info("Vocab size : {}".format(n_token))
|
||||
tf.logging.info("Batch size : {}".format(batch_size))
|
||||
|
||||
loss, new_mems = model.transformer(
|
||||
dec_inp=inp,
|
||||
target=tgt,
|
||||
mems=mems,
|
||||
n_token=n_token,
|
||||
n_layer=FLAGS.n_layer,
|
||||
d_model=FLAGS.d_model,
|
||||
d_embed=FLAGS.d_embed,
|
||||
n_head=FLAGS.n_head,
|
||||
d_head=FLAGS.d_head,
|
||||
d_inner=FLAGS.d_inner,
|
||||
dropout=FLAGS.dropout,
|
||||
dropatt=FLAGS.dropatt,
|
||||
initializer=initializer,
|
||||
is_training=is_training,
|
||||
mem_len=FLAGS.mem_len,
|
||||
cutoffs=cutoffs,
|
||||
div_val=FLAGS.div_val,
|
||||
tie_projs=tie_projs,
|
||||
input_perms=inp_perms,
|
||||
target_perms=tgt_perms,
|
||||
head_target=head_tgt,
|
||||
same_length=FLAGS.same_length,
|
||||
clamp_len=FLAGS.clamp_len,
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
untie_r=FLAGS.untie_r,
|
||||
proj_same_dim=FLAGS.proj_same_dim)
|
||||
|
||||
total_loss = tf.reduce_mean(loss)
|
||||
|
||||
if mode == tf.estimator.ModeKeys.EVAL:
|
||||
if FLAGS.use_tpu:
|
||||
with tf.colocate_with(total_loss):
|
||||
total_loss = tf.contrib.tpu.cross_replica_sum(total_loss) \
|
||||
/ FLAGS.num_hosts / FLAGS.num_core_per_host
|
||||
metric_loss = tf.tile(tf.reshape(total_loss, [1, 1]), [batch_size, 1])
|
||||
eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode,
|
||||
loss=total_loss,
|
||||
eval_metrics=(metric_fn, [metric_loss]))
|
||||
|
||||
eval_spec.cache = new_mems
|
||||
|
||||
return eval_spec
|
||||
|
||||
# Configuring the optimization step.
|
||||
global_step = tf.train.get_global_step()
|
||||
|
||||
# increase the learning rate linearly
|
||||
if FLAGS.warmup_steps > 0:
|
||||
warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
|
||||
* FLAGS.learning_rate
|
||||
else:
|
||||
warmup_lr = 0.0
|
||||
|
||||
# number of parameters
|
||||
num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
|
||||
tf.logging.info("#params: {}".format(num_params))
|
||||
|
||||
# format_str = '{{:<{0}s}}\t{{}}'.format(
|
||||
# max([len(v.name) for v in tf.trainable_variables()]))
|
||||
# for v in tf.trainable_variables():
|
||||
# tf.logging.info(format_str.format(v.name, v.get_shape()))
|
||||
|
||||
|
||||
# decay the learning rate using the cosine schedule
|
||||
decay_lr = tf.train.cosine_decay(
|
||||
FLAGS.learning_rate,
|
||||
global_step=global_step-FLAGS.warmup_steps,
|
||||
decay_steps=FLAGS.train_steps-FLAGS.warmup_steps,
|
||||
alpha=FLAGS.min_lr_ratio)
|
||||
|
||||
learning_rate = tf.where(global_step < FLAGS.warmup_steps,
|
||||
warmup_lr, decay_lr)
|
||||
|
||||
if FLAGS.use_tpu:
|
||||
optimizer = tf.contrib.tpu.CrossShardOptimizer(
|
||||
tf.train.AdamOptimizer(learning_rate=learning_rate))
|
||||
#GradientDescentOptimizer
|
||||
else:
|
||||
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
|
||||
|
||||
grads_and_vars = optimizer.compute_gradients(total_loss)
|
||||
gradients, variables = zip(*grads_and_vars)
|
||||
clipped, _ = tf.clip_by_global_norm(gradients, FLAGS.clip)
|
||||
train_op = optimizer.apply_gradients(
|
||||
zip(clipped, variables), global_step=tf.train.get_global_step())
|
||||
|
||||
# Constucting TPUEstimatorSpec with cache.
|
||||
train_spec = tf.contrib.tpu.TPUEstimatorSpec(
|
||||
mode=mode, loss=total_loss, train_op=train_op)
|
||||
|
||||
if FLAGS.mem_len < FLAGS.tgt_len:
|
||||
new_mems = [new_mems[: FLAGS.mem_len] for mem_t in new_mems]
|
||||
train_spec.cache = new_mems
|
||||
|
||||
return train_spec
|
||||
|
||||
return model_fn
|
||||
|
||||
|
||||
def get_cache_fn(mem_len):
|
||||
|
||||
def cache_fn(batch_size):
|
||||
mems = []
|
||||
for l in xrange(FLAGS.n_layer):
|
||||
if mem_len > 0:
|
||||
mems.append(
|
||||
tf.zeros([mem_len, batch_size, FLAGS.d_model], dtype=tf.float32))
|
||||
else:
|
||||
mems.append(tf.zeros([mem_len], dtype=tf.float32))
|
||||
|
||||
return mems
|
||||
|
||||
return cache_fn
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
del unused_argv # Unused
|
||||
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
# Get corpus info
|
||||
corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
|
||||
n_token = corpus_info["vocab_size"]
|
||||
cutoffs = corpus_info["cutoffs"][1:-1]
|
||||
|
||||
if FLAGS.save_steps == 0:
|
||||
FLAGS.save_steps = None
|
||||
|
||||
if not FLAGS.do_eval_only:
|
||||
# Get train input function
|
||||
train_input_fn, train_record_info = data_utils.get_input_fn(
|
||||
record_info_dir=FLAGS.record_info_dir,
|
||||
split="train",
|
||||
per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
|
||||
tgt_len=FLAGS.tgt_len,
|
||||
num_core_per_host=FLAGS.num_core_per_host,
|
||||
num_hosts=FLAGS.num_hosts,
|
||||
use_tpu=FLAGS.use_tpu)
|
||||
train_bin_sizes = train_record_info["bin_sizes"]
|
||||
num_train_batch = train_record_info["num_batch"]
|
||||
|
||||
# Get train cache function
|
||||
train_cache_fn = get_cache_fn(FLAGS.mem_len)
|
||||
else:
|
||||
train_bin_sizes = []
|
||||
num_train_batch = None
|
||||
train_cache_fn = None
|
||||
|
||||
if FLAGS.do_eval or FLAGS.do_eval_only:
|
||||
assert FLAGS.num_hosts == 1
|
||||
# Get eval input function
|
||||
eval_input_fn, eval_record_info = data_utils.get_input_fn(
|
||||
record_info_dir=FLAGS.record_info_dir,
|
||||
split=FLAGS.eval_split,
|
||||
per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
|
||||
tgt_len=FLAGS.tgt_len,
|
||||
num_core_per_host=FLAGS.num_core_per_host,
|
||||
num_hosts=FLAGS.num_hosts,
|
||||
use_tpu=FLAGS.use_tpu)
|
||||
eval_bin_sizes = eval_record_info["bin_sizes"]
|
||||
num_eval_batch = eval_record_info["num_batch"]
|
||||
|
||||
if FLAGS.max_eval_batch > 0:
|
||||
num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch)
|
||||
|
||||
# Get eval cache function
|
||||
eval_cache_fn = get_cache_fn(FLAGS.mem_len)
|
||||
model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes)
|
||||
else:
|
||||
eval_cache_fn = None
|
||||
model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, [])
|
||||
|
||||
##### Create estimator
|
||||
# TPU Configuration
|
||||
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
|
||||
FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)
|
||||
|
||||
per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
|
||||
run_config = tf.contrib.tpu.RunConfig(
|
||||
cluster=tpu_cluster_resolver,
|
||||
model_dir=FLAGS.model_dir,
|
||||
session_config=tf.ConfigProto(
|
||||
allow_soft_placement=True, log_device_placement=True),
|
||||
tpu_config=tf.contrib.tpu.TPUConfig(
|
||||
iterations_per_loop=FLAGS.iterations,
|
||||
num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts,
|
||||
per_host_input_for_training=per_host_input),
|
||||
keep_checkpoint_max=100000, # effectively save all checkpoints
|
||||
save_checkpoints_secs=None,
|
||||
save_checkpoints_steps=FLAGS.save_steps
|
||||
)
|
||||
|
||||
# warm start
|
||||
warm_start_from = None
|
||||
if FLAGS.warm_start_path is not None:
|
||||
warm_start_from = tf.estimator.WarmStartSettings(
|
||||
ckpt_to_initialize_from=FLAGS.warm_start_path)
|
||||
|
||||
# TPU Estimator
|
||||
estimator = tpu_estimator.TPUEstimator(
|
||||
model_fn=model_fn,
|
||||
train_cache_fn=train_cache_fn,
|
||||
eval_cache_fn=eval_cache_fn,
|
||||
use_tpu=FLAGS.use_tpu,
|
||||
config=run_config,
|
||||
params={"data_dir":FLAGS.data_dir, "track_mean":FLAGS.track_mean},
|
||||
train_batch_size=FLAGS.train_batch_size,
|
||||
eval_batch_size=FLAGS.eval_batch_size,
|
||||
warm_start_from=warm_start_from)
|
||||
|
||||
if FLAGS.do_eval_only:
|
||||
if FLAGS.eval_ckpt_path is not None:
|
||||
ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch,
|
||||
checkpoint_path=FLAGS.eval_ckpt_path)
|
||||
tf.logging.info("=" * 200)
|
||||
log_str = "Eval results | "
|
||||
for key, val in ret.items():
|
||||
log_str += "{} {} | ".format(key, val)
|
||||
tf.logging.info(log_str)
|
||||
tf.logging.info("=" * 200)
|
||||
else:
|
||||
ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
|
||||
eval_results = []
|
||||
for eval_checkpoint in ckpt_state.all_model_checkpoint_paths:
|
||||
if not exists(eval_checkpoint + ".index"): continue
|
||||
global_step = int(eval_checkpoint.split("-")[-1])
|
||||
if global_step < FLAGS.start_eval_steps or global_step > FLAGS.train_steps:
|
||||
continue
|
||||
ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch,
|
||||
checkpoint_path=eval_checkpoint)
|
||||
eval_results.append(ret)
|
||||
|
||||
eval_results.sort(key = lambda x: x["perplexity"])
|
||||
|
||||
tf.logging.info("=" * 200)
|
||||
log_str = "Best results | "
|
||||
for key, val in eval_results[0].items():
|
||||
log_str += "{} {} | ".format(key, val)
|
||||
tf.logging.info(log_str)
|
||||
tf.logging.info("=" * 200)
|
||||
else:
|
||||
if not FLAGS.do_eval:
|
||||
estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)
|
||||
else:
|
||||
for step in range(0, FLAGS.train_steps, num_train_batch):
|
||||
train_steps = min(FLAGS.train_steps - step, num_train_batch)
|
||||
estimator.train(input_fn=train_input_fn, steps=train_steps)
|
||||
estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
Reference in a new issue