chore(transformer-xl): Initial commit
parent ef4684ef39
commit 10512876f2
46 changed files with 10547 additions and 0 deletions
475  transformer-xl/tf/train_gpu.py  (new file)

@@ -0,0 +1,475 @@
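# Multi-GPU training and evaluation script for Transformer-XL
# (TensorFlow 1.x, in-graph replication across GPUs on a single host).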
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import time

from absl import flags
import absl.logging as _logging  # pylint: disable=unused-import

import tensorflow as tf
import model
import data_utils

from gpu_utils import assign_to_gpu, average_grads_and_vars

import numpy as np

# GPU config
flags.DEFINE_integer("num_hosts", default=1,
      help="Number of hosts.")
flags.DEFINE_integer("num_core_per_host", default=8,
      help="Number of cores (GPUs) per host.")

# Experiment (data/checkpoint/directory) config
flags.DEFINE_string("data_dir", default="",
      help="Path to tf-records directory.")
flags.DEFINE_string("record_info_dir", default="",
      help="Path to local directory containing filenames.txt.")
flags.DEFINE_string("corpus_info_path", default="",
      help="Path to corpus-info.json file.")
flags.DEFINE_string("model_dir", default=None,
      help="Estimator model_dir.")
flags.DEFINE_bool("do_train", default=True,
      help="Whether to run training.")
flags.DEFINE_bool("do_eval", default=False,
      help="Whether to run eval on the dev set.")
flags.DEFINE_string("eval_ckpt_path", None,
      help="Checkpoint path for do_test evaluation. "
           "If set, model_dir will be ignored. "
           "If unset, will use the latest ckpt in model_dir.")
flags.DEFINE_string("warm_start_path", None,
      help="Checkpoint path for warm start. "
           "If set, will clear Adam states. "
           "Note that the new model_dir should be different"
           " from warm_start_path.")

# Optimization config
flags.DEFINE_float("learning_rate", default=2.5e-4,
      help="Maximum learning rate.")
flags.DEFINE_float("clip", default=0.25,
      help="Gradient clipping value.")
# for cosine decay
flags.DEFINE_float("min_lr_ratio", default=0.004,
      help="Minimum ratio learning rate.")
flags.DEFINE_integer("warmup_steps", default=0,
      help="Number of steps for linear lr warmup.")

# Training config
flags.DEFINE_integer("train_batch_size", default=60,
      help="Size of train batch.")
flags.DEFINE_integer("eval_batch_size", default=60,
      help="Size of valid batch.")
flags.DEFINE_integer("train_steps", default=100000,
      help="Total number of training steps.")
flags.DEFINE_integer("iterations", default=500,
      help="Number of iterations per repeat loop.")
flags.DEFINE_integer("save_steps", default=10000,
      help="Number of steps for model checkpointing.")

# Evaluation config
flags.DEFINE_bool("do_test", default=False,
      help="Run on the test set.")
flags.DEFINE_integer("max_eval_batch", default=-1,
      help="Set -1 to turn off. Only used in test mode.")
flags.DEFINE_bool("do_eval_only", default=False,
      help="Run evaluation only.")
flags.DEFINE_integer("start_eval_steps", default=10000,
      help="Which checkpoint to start with in `do_eval_only` mode.")
flags.DEFINE_string("eval_split", "valid",
      help="Which data split to evaluate.")

# Model config
flags.DEFINE_integer("tgt_len", default=70,
      help="Number of steps to predict.")
flags.DEFINE_integer("mem_len", default=70,
      help="Number of steps to cache.")
flags.DEFINE_bool("same_length", default=False,
      help="Same length attention.")
flags.DEFINE_integer("clamp_len", default=-1,
      help="Clamp length.")

flags.DEFINE_integer("n_layer", default=6,
      help="Number of layers.")
flags.DEFINE_integer("d_model", default=500,
      help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=500,
      help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=10,
      help="Number of attention heads.")
flags.DEFINE_integer("d_head", default=50,
      help="Dimension of each attention head.")
flags.DEFINE_integer("d_inner", default=1000,
      help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1,
      help="Dropout rate.")
flags.DEFINE_float("dropatt", default=0.1,
      help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False,
      help="Untie r_w_bias and r_r_bias.")

# Adaptive Softmax / Embedding
flags.DEFINE_bool("tie_weight", default=True,
      help="Tie embedding and softmax weight.")
flags.DEFINE_integer("div_val", default=1,
      help="Divide the embedding size by this val for each bin.")
flags.DEFINE_bool("proj_share_all_but_first", default=False,
      help="True to share all but first projs, False not to share.")
flags.DEFINE_bool("proj_same_dim", default=True,
      help="Project the bin with the same dimension.")

# Parameter initialization
flags.DEFINE_enum("init", default="normal",
      enum_values=["normal", "uniform"],
      help="Initialization method.")
flags.DEFINE_float("init_std", default=0.02,
      help="Initialization std when init is normal.")
flags.DEFINE_float("proj_init_std", default=0.01,
      help="Initialization std for embedding projection.")
flags.DEFINE_float("init_range", default=0.1,
      help="Initialization range when init is uniform.")

FLAGS = flags.FLAGS


def get_model_fn(n_token, cutoffs):
  def model_fn(inp, tgt, mems, is_training):
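    # The input pipeline yields batch-major [bsz, tgt_len] tensors; the model
    # is time-major, so transpose inputs and targets to [tgt_len, bsz].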
    inp = tf.transpose(inp, [1, 0])
    tgt = tf.transpose(tgt, [1, 0])

    # proj_initializer stays None unless set below; model.transformer is
    # expected to fall back to `initializer` in that case (this avoids a
    # NameError when init == "uniform").
    proj_initializer = None
    if FLAGS.init == "uniform":
      initializer = tf.initializers.random_uniform(
          minval=-FLAGS.init_range,
          maxval=FLAGS.init_range,
          seed=None)
    elif FLAGS.init == "normal":
      initializer = tf.initializers.random_normal(
          stddev=FLAGS.init_std,
          seed=None)
      proj_initializer = tf.initializers.random_normal(
          stddev=FLAGS.proj_init_std,
          seed=None)

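    # With proj_share_all_but_first, every adaptive-softmax cluster except the
    # first has its projection tied (tie_projs[i] = True).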
    tie_projs = [False for _ in range(len(cutoffs) + 1)]
    if FLAGS.proj_share_all_but_first:
      for i in range(1, len(tie_projs)):
        tie_projs[i] = True

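    # Build the Transformer-XL graph for one segment: returns the LM loss and
    # the updated memories (hidden states cached for the next segment).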
    loss, new_mems = model.transformer(
        dec_inp=inp,
        target=tgt,
        mems=mems,
        n_token=n_token,
        n_layer=FLAGS.n_layer,
        d_model=FLAGS.d_model,
        d_embed=FLAGS.d_embed,
        n_head=FLAGS.n_head,
        d_head=FLAGS.d_head,
        d_inner=FLAGS.d_inner,
        dropout=FLAGS.dropout,
        dropatt=FLAGS.dropatt,
        initializer=initializer,
        proj_initializer=proj_initializer,
        is_training=is_training,
        mem_len=FLAGS.mem_len,
        cutoffs=cutoffs,
        div_val=FLAGS.div_val,
        tie_projs=tie_projs,
        input_perms=None,
        target_perms=None,
        head_target=None,
        same_length=FLAGS.same_length,
        clamp_len=FLAGS.clamp_len,
        use_tpu=False,
        untie_r=FLAGS.untie_r,
        proj_same_dim=FLAGS.proj_same_dim)

    # number of parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info('#params: {}'.format(num_params))

    # format_str = '{{:<{0}s}}\t{{}}'.format(
    #     max([len(v.name) for v in tf.trainable_variables()]))
    # for v in tf.trainable_variables():
    #   tf.logging.info(format_str.format(v.name, v.get_shape()))

    if is_training:
      all_vars = tf.trainable_variables()
      grads = tf.gradients(loss, all_vars)
      grads_and_vars = list(zip(grads, all_vars))

      return loss, new_mems, grads_and_vars
    else:
      return loss, new_mems

  return model_fn


def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems):
  model_fn = get_model_fn(
      n_token=n_token,
      cutoffs=cutoffs)

  model_ret = model_fn(
      inp=inp,
      tgt=tgt,
      mems=mems,
      is_training=is_training)

  return model_ret


def train(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  train_input_fn, train_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split="train",
      per_host_bsz=FLAGS.train_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

  ##### Create computational graph
  train_set = train_input_fn({
      "batch_size": FLAGS.train_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

  tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

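  # Build one tower per GPU: each gets an equal slice of the batch and its own
  # memory placeholders (one [mem_len, bsz, d_model] tensor per layer).
  # Variables are created on the first tower and reused by the rest.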
  for i in range(FLAGS.num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=True,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## average losses and gradients across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]
  grads, all_vars = zip(*grads_and_vars)

  ## clip gradient
  clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
  grads_and_vars = list(zip(clipped, all_vars))

  ## configure the optimizer
  global_step = tf.train.get_or_create_global_step()

  # warmup stage: increase the learning rate linearly
  if FLAGS.warmup_steps > 0:
    warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
        * FLAGS.learning_rate
  else:
    warmup_lr = 0.0

  # decay stage: decay the learning rate using the cosine schedule
  decay_lr = tf.train.cosine_decay(
      FLAGS.learning_rate,
      global_step=global_step - FLAGS.warmup_steps,
      decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
      alpha=FLAGS.min_lr_ratio)

  # choose warmup or decay
  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                           warmup_lr, decay_lr)

  # get the train op
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  train_op = optimizer.apply_gradients(grads_and_vars, global_step)

  ##### Training loop
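  # Memories start as zeros; every sess.run below returns updated memories,
  # which are fed back in on the next iteration.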
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.warm_start_path is not None:
      tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
      saver.restore(sess, FLAGS.warm_start_path)

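    # fetched[:3] below unpacks loss, new memories, and the global step;
    # fetched[-3] and fetched[-2] are the gradient norm and learning rate
    # used for logging.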
    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    total_loss, prev_step = 0., -1
    while True:
      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

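      # Keep the updated memories so they are fed back in on the next step.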
      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
            curr_step, fetched[-3], fetched[-2],
            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Model saved in path: {}".format(save_path))

      if curr_step == FLAGS.train_steps:
        break


def evaluate(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  eval_input_fn, eval_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split=FLAGS.eval_split,
      per_host_bsz=FLAGS.eval_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  num_batch = eval_record_info["num_batch"]
  if FLAGS.max_eval_batch > 0:
    num_batch = FLAGS.max_eval_batch
  tf.logging.info("num of batches {}".format(num_batch))

  ##### Create computational graph
  eval_set = eval_input_fn({
      "batch_size": FLAGS.eval_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
  tower_mems, tower_losses, tower_new_mems = [], [], []

  for i in range(FLAGS.num_core_per_host):
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=False,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)

  ## average losses across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
  else:
    loss = tower_losses[0]

  ##### Evaluation loop
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.eval_ckpt_path is None:
      eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    else:
      eval_ckpt_path = FLAGS.eval_ckpt_path
    tf.logging.info("Evaluate {}".format(eval_ckpt_path))
    saver.restore(sess, eval_ckpt_path)

    fetches = [loss, tower_new_mems, tf.size(label_feed)]

    format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
        len(str(num_batch)))

    total_loss, total_cnt = 0, 0
    for step in range(num_batch):
      # max(..., 1) guards against a division by zero when num_batch < 10.
      if step % max(num_batch // 10, 1) == 0:
        tf.logging.info(format_str.format(step, num_batch))

      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, cnt_np = fetched[:3]
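      # Weight each batch's loss by its token count so the final average is a
      # per-token loss.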
      total_loss += loss_np * cnt_np
      total_cnt += cnt_np

    avg_loss = total_loss / total_cnt
    tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
        avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))


def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
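  # Drop the first and last entries of the stored cutoff list, keeping only
  # the internal cutoff points used by the adaptive softmax.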
  cutoffs = corpus_info["cutoffs"][1:-1]
  tf.logging.info("n_token {}".format(n_token))

  if FLAGS.do_train:
    train(n_token, cutoffs, "/gpu:0")
  if FLAGS.do_eval:
    evaluate(n_token, cutoffs, "/gpu:0")


if __name__ == "__main__":
  tf.app.run()