from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
from functools import partial

from collections import Counter, OrderedDict
import pickle
import json
import multiprocessing as mp

import numpy as np

from absl import flags
import tensorflow as tf
from vocabulary import Vocab

from tensorflow.gfile import Exists as exists
from tensorflow.gfile import MakeDirs as makedirs
from tensorflow.gfile import Glob as glob


def _preprocess(shard, train, vocab, save_dir, cutoffs, bin_sizes, bsz, tgt_len,
                num_core_per_host, use_tpu, num_shuffle):
  """Encodes one lm1b training shard and writes `num_shuffle` shuffled
  TFRecord files for it."""
  file_names = []
  num_batch = 0

  path = train[shard]
  data_shard = vocab.encode_file(path, ordered=False, add_double_eos=True)

  for shuffle in range(num_shuffle):
    basename = "train-{:03d}-{:02d}".format(shard, shuffle)
    print("Processing shard {} shuffle {}".format(shard, shuffle))

    np.random.shuffle(data_shard)
    file_name, num_batch_shuffle = create_ordered_tfrecords(
        save_dir, basename, np.concatenate(data_shard), bsz, tgt_len,
        num_core_per_host, cutoffs, bin_sizes, use_tpu=use_tpu)
    file_names.append(file_name)
    num_batch += num_batch_shuffle

  return file_names, num_batch


class Corpus(object):
  """Builds the vocabulary for a dataset and prepares its train/valid/test splits."""

  def __init__(self, path, dataset, *args, **kwargs):
    self.dataset = dataset
    self.vocab = Vocab(*args, **kwargs)

    if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
      self.vocab.count_file(os.path.join(path, "train.txt"))
      self.vocab.count_file(os.path.join(path, "valid.txt"))
      self.vocab.count_file(os.path.join(path, "test.txt"))
    elif self.dataset == "wt103":
      self.vocab.count_file(os.path.join(path, "train.txt"))
    elif self.dataset == "lm1b":
      train_path_pattern = os.path.join(
          path, "1-billion-word-language-modeling-benchmark-r13output",
          "training-monolingual.tokenized.shuffled", "news.en-*")
      train_paths = glob(train_path_pattern)

      # the vocab will load from file when build_vocab() is called
      # for train_path in sorted(train_paths):
      #   self.vocab.count_file(train_path, verbose=True)

    self.vocab.build_vocab()

    if self.dataset in ["ptb", "wt2", "wt103"]:
      self.train = self.vocab.encode_file(
          os.path.join(path, "train.txt"), ordered=True)
      self.valid = self.vocab.encode_file(
          os.path.join(path, "valid.txt"), ordered=True)
      self.test = self.vocab.encode_file(
          os.path.join(path, "test.txt"), ordered=True)
    elif self.dataset in ["enwik8", "text8"]:
      self.train = self.vocab.encode_file(
          os.path.join(path, "train.txt"), ordered=True, add_eos=False)
      self.valid = self.vocab.encode_file(
          os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
      self.test = self.vocab.encode_file(
          os.path.join(path, "test.txt"), ordered=True, add_eos=False)
    elif self.dataset == "lm1b":
      self.train = train_paths
      valid_path = os.path.join(path, "valid.txt")
      test_path = valid_path
      self.valid = self.vocab.encode_file(
          valid_path, ordered=True, add_double_eos=True)
      self.test = self.vocab.encode_file(
          test_path, ordered=True, add_double_eos=True)

    if self.dataset == "wt103":
      self.cutoffs = [0, 20000, 40000, 200000] + [len(self.vocab)]
    elif self.dataset == "lm1b":
      self.cutoffs = [0, 60000, 100000, 640000] + [len(self.vocab)]
    else:
      self.cutoffs = []

  def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len,
                           num_core_per_host, **kwargs):
    """Converts one split to TFRecord files and writes the record-info JSON."""
    FLAGS = kwargs.get('FLAGS')

    file_names = []
    use_tpu = FLAGS.use_tpu and not (split == "test" and num_core_per_host == 1)

    if use_tpu:
      record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
          split, bsz, tgt_len, num_core_per_host)
    else:
      record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
          split, bsz, tgt_len)

    record_info_path = os.path.join(save_dir, record_name)

    if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
      data = getattr(self, split)
      bin_sizes = get_bin_sizes(
          data, bsz // num_core_per_host, tgt_len, self.cutoffs)
      file_name, num_batch = create_ordered_tfrecords(
          save_dir, split, data, bsz, tgt_len, num_core_per_host,
          self.cutoffs, bin_sizes,
          num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1,
          use_tpu=use_tpu)
      file_names.append(file_name)
    elif self.dataset == "lm1b":
      bin_sizes = get_bin_sizes(
          self.valid, bsz // num_core_per_host, tgt_len, self.cutoffs)
      if split == "train":
        np.random.seed(123456)
        num_batch = 0

        if FLAGS.num_procs > 1:
          _preprocess_wrapper = partial(_preprocess,
              train=self.train, vocab=self.vocab, save_dir=save_dir,
              cutoffs=self.cutoffs, bin_sizes=bin_sizes, bsz=bsz,
              tgt_len=tgt_len, num_core_per_host=num_core_per_host,
              use_tpu=use_tpu, num_shuffle=FLAGS.num_shuffle)

          pool = mp.Pool(processes=FLAGS.num_procs)
          results = pool.map(_preprocess_wrapper, range(len(self.train)))
          for res in results:
            file_names.extend(res[0])
            num_batch += res[1]
        else:
          for shard, path in enumerate(self.train):
            data_shard = self.vocab.encode_file(path, ordered=False,
                                                add_double_eos=True)

            num_shuffle = FLAGS.num_shuffle

            for shuffle in range(num_shuffle):
              print("Processing shard {} shuffle {}".format(shard, shuffle))
              basename = "train-{:03d}-{:02d}".format(shard, shuffle)
              np.random.shuffle(data_shard)
              file_name, num_batch_ = create_ordered_tfrecords(
                  save_dir, basename, np.concatenate(data_shard), bsz, tgt_len,
                  num_core_per_host,
                  self.cutoffs, bin_sizes, use_tpu=use_tpu)
              file_names.append(file_name)
              num_batch += num_batch_

      else:
        file_name, num_batch = create_ordered_tfrecords(
            save_dir, split, getattr(self, split), bsz, tgt_len,
            num_core_per_host,
            self.cutoffs, bin_sizes, use_tpu=use_tpu)
        file_names.append(file_name)

    with open(record_info_path, "w") as fp:
      record_info = {
          "filenames": file_names,
          "bin_sizes": bin_sizes,
          "num_batch": num_batch
      }
      json.dump(record_info, fp)


def get_bin_sizes(data, batch_size, tgt_len, cutoffs, std_mult=[2.5, 2.5, 2.5]):
  """
  Note: the `batch_size` here should be the per-core batch size.
  """
  bin_sizes = []

  # round each bin size to a multiple of 8 so that it is faster on TPUs
  def _nearest_to_eight(x):
    y = x - x % 8
    return y + 8 if x % 8 >= 4 else max(8, y)

  if cutoffs:
    num_batch = len(data) // batch_size // tgt_len

    data = data[:batch_size * num_batch * tgt_len]
    data = data.reshape(batch_size, num_batch, tgt_len)

    tot = batch_size * tgt_len
    for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])):
      mask = (data >= left) * (data < right)
      percents = mask.astype(np.float64).sum(2).sum(0) / tot
      mean = np.mean(percents)
      std = np.std(percents)

      bin_size = int(math.ceil(tgt_len * batch_size * (mean + std_mult[b] * std)))
      bin_size = _nearest_to_eight(bin_size)
      bin_sizes.append(bin_size)

  return bin_sizes

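# A small worked example of the rounding above (assumed values, not used by
# the pipeline): `_nearest_to_eight` maps a raw bin size to a multiple of 8,
# never going below 8, which keeps the per-bin tensors TPU-friendly.
#   13 -> 16   (13 % 8 = 5 >= 4, so round up to 16)
#   11 ->  8   (11 % 8 = 3 <  4, so round down to 8)
#    3 ->  8   (rounds down to 0, then clamps to the minimum of 8)

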
def _int64_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

def _float_feature(values):
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def batchify(data, batch_size, num_passes):
  """
  When use_tpu is True, num_passes should be > 1.

  TPU training consumes whole [bsz x tgt_len] chunks, so up to
  `bsz * tgt_len` tokens can be discarded from the end of the data. When
  `bsz` and `tgt_len` are both large, as in TPU training of Transformer-XL,
  this loss may lead to a detectable performance drop.

  Here, we concatenate multiple randomly shifted copies of the data, so the
  discarded tokens and chunk boundaries differ from pass to pass.
  """
  if num_passes > 1:
    data_len = len(data)
    double_data = np.concatenate([data, data])
    data_list = []
    for i in range(num_passes):
      start = np.random.randint(0, data_len)
      data_list.append(double_data[start:start+data_len])
    data = np.concatenate(data_list)

  num_step = len(data) // batch_size
  data = data[:batch_size * num_step]
  data = data.reshape(batch_size, num_step)

  return data

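# Illustrative shape walk-through for `batchify` (assumed sizes, not used by
# the pipeline): with len(data) == 1000 and batch_size == 4, the first
# 4 * 250 = 1000 tokens are kept and the result has shape (4, 250). With
# num_passes == 2, two randomly shifted copies are concatenated first
# (2000 tokens), giving shape (4, 500); each row is later cut into
# tgt_len-sized input/label windows by `create_ordered_tfrecords`.

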
def create_ordered_tfrecords(save_dir, basename, data, batch_size, tgt_len,
                             num_core_per_host, cutoffs=[], bin_sizes=[],
                             num_passes=1, use_tpu=False):
  """Chunks `data` into input/label windows of length `tgt_len` and writes
  them to a single TFRecord file."""
  if use_tpu:
    file_name = "{}.bsz-{}.tlen-{}.core-{}.tfrecords".format(
        basename, batch_size, tgt_len, num_core_per_host)
  else:
    file_name = "{}.bsz-{}.tlen-{}.tfrecords".format(
        basename, batch_size, tgt_len)

  save_path = os.path.join(save_dir, file_name)
  record_writer = tf.python_io.TFRecordWriter(save_path)

  batched_data = batchify(data, batch_size, num_passes)

  num_batch = 0
  # for t in range(0, batched_data.shape[1] - tgt_len - 1, tgt_len):
  for t in range(0, batched_data.shape[1] - 1, tgt_len):
    cur_tgt_len = min(batched_data.shape[1] - 1 - t, tgt_len)
    # drop the remainder if use_tpu
    if use_tpu and cur_tgt_len < tgt_len:
      break
    if num_batch % 500 == 0:
      print(" processing batch {}".format(num_batch))
    for idx in range(batch_size):
      inputs = batched_data[idx, t:t + cur_tgt_len]
      labels = batched_data[idx, t + 1:t + cur_tgt_len + 1]

      # features dict
      feature = {
          "inputs": _int64_feature(inputs),
          "labels": _int64_feature(labels),
      }

      if len(cutoffs) > 0 and use_tpu:
        # validate `bin_sizes` and `cutoffs`
        assert len(cutoffs) - len(bin_sizes) == 2, \
            "len(cutoffs) - len(bin_sizes) != 2"

        # mask for bin 0
        left, right = cutoffs[:2]
        inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32)
        tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32)

        feature["inp_mask"] = _float_feature(inp_mask)
        feature["tgt_mask"] = _float_feature(tgt_mask)

        # refresh `inp_cnts` and `tgt_cnts` for each TPU core
        if idx % (batch_size // num_core_per_host) == 0:
          inp_cnts = [0] * len(bin_sizes)
          tgt_cnts = [0] * len(bin_sizes)

        head_labels = np.copy(labels)
        inp_pos_per_bin, tgt_pos_per_bin = [], []
        for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])):
          inp_pos = np.where((inputs >= left) * (inputs < right))[0]
          tgt_pos = np.where((labels >= left) * (labels < right))[0]
          inp_pos_per_bin.append(inp_pos)
          tgt_pos_per_bin.append(tgt_pos)

          head_labels[tgt_pos] = cutoffs[1] + b

        feature["head_labels"] = _int64_feature(head_labels)

        # permutation feature
        def _add_perm_feature(feature, pos_per_bin, cnts, prefix):
          for b, pos in enumerate(pos_per_bin):
            idx_tuple = []
            for p in pos:
              if cnts[b] < bin_sizes[b]:
                idx_tuple.append([p, cnts[b]])
                cnts[b] += 1
              else:
                break

            n_tup = len(idx_tuple)
            tup = np.array(idx_tuple).reshape(n_tup * 2)

            feature["{}_cnt_{}".format(prefix, b)] = _int64_feature([n_tup])
            feature["{}_tup_{}".format(prefix, b)] = _int64_feature(tup)

        _add_perm_feature(feature, inp_pos_per_bin, inp_cnts, "inp")
        _add_perm_feature(feature, tgt_pos_per_bin, tgt_cnts, "tgt")

      example = tf.train.Example(features=tf.train.Features(feature=feature))
      record_writer.write(example.SerializeToString())

    num_batch += 1

  record_writer.close()
  print("Done writing {}. batches: {}".format(file_name, num_batch))

  return file_name, num_batch


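# How the permutation features written above are consumed (illustrative,
# assumed values): for each tail bin b, `_add_perm_feature` stores a count and
# a flattened list of (position-in-sequence, slot-in-bin) pairs. At input time,
# `get_input_fn` scatters those pairs into a dense [tgt_len, bin_sizes[b]] 0/1
# matrix; e.g. with tgt_len = 4 and bin_sizes[b] = 8, the pairs (1, 0) and
# (3, 1) produce a 4 x 8 matrix whose only nonzero entries are at rows 1 and 3,
# columns 0 and 1.

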
def get_lm_corpus(data_dir, dataset):
  """Loads a cached Corpus if available; otherwise builds and caches one."""
  fn = os.path.join(data_dir, "cache.pkl")

  if exists(fn):
    print("Loading cached dataset...")
    with open(fn, "rb") as fp:
      corpus = pickle.load(fp)
  else:
    print("Producing dataset...")
    kwargs = {}
    if dataset in ["wt103", "wt2"]:
      kwargs["special"] = ["<eos>"]
      kwargs["lower_case"] = False
    elif dataset == "ptb":
      kwargs["special"] = ["<eos>"]
      kwargs["lower_case"] = True
    elif dataset == "lm1b":
      kwargs["special"] = []
      kwargs["lower_case"] = False
      kwargs["vocab_file"] = os.path.join(data_dir, "1b_word_vocab.txt")
    elif dataset in ["enwik8", "text8"]:
      pass

    corpus = Corpus(data_dir, dataset, **kwargs)

    print("Saving dataset...")
    with open(fn, "wb") as fp:
      pickle.dump(corpus, fp, protocol=2)

  corpus_info = {
      "vocab_size": len(corpus.vocab),
      "cutoffs": corpus.cutoffs,
      "dataset": corpus.dataset
  }
  with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp:
    json.dump(corpus_info, fp)

  return corpus

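# Minimal standalone usage sketch (the path here is an assumption):
#   corpus = get_lm_corpus("./data/wikitext-103", "wt103")
#   print(len(corpus.vocab), corpus.cutoffs)
# The first call builds the Corpus and pickles it to <data_dir>/cache.pkl;
# subsequent calls load the cache and only rewrite corpus-info.json.

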
def main(unused_argv):
  del unused_argv  # Unused

  corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

  save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
  if not exists(save_dir):
    makedirs(save_dir)

  # test mode
  if FLAGS.per_host_test_bsz > 0:
    corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
                                FLAGS.tgt_len, FLAGS.num_core_per_host,
                                FLAGS=FLAGS)
    return

  for split, batch_size in zip(
      ["train", "valid"],
      [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):

    if batch_size <= 0: continue
    print("Converting {} set...".format(split))
    corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                                FLAGS.num_core_per_host, FLAGS=FLAGS)


def load_record_info(record_info_dir, split, per_host_bsz, tgt_len,
                     num_core_per_host, use_tpu):
  if use_tpu:
    record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
        split, per_host_bsz, tgt_len, num_core_per_host)
  else:
    record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
        split, per_host_bsz, tgt_len)

  record_info_path = os.path.join(record_info_dir, record_name)
  with open(record_info_path, "r") as fp:
    record_info = json.load(fp)

  return record_info

def get_input_fn(record_info_dir, split, per_host_bsz, tgt_len,
                 num_core_per_host, num_hosts=1, use_tpu=False):
  """Creates input function."""
  record_info = load_record_info(record_info_dir, split, per_host_bsz, tgt_len,
                                 num_core_per_host, use_tpu=use_tpu)

  file_names = record_info["filenames"]
  bin_sizes = record_info["bin_sizes"]
  num_batch = record_info["num_batch"]

  tf.logging.info("[{}] File names {}".format(split, file_names))

  def input_fn(params):
    # per-core batch size
    per_core_bsz = params["batch_size"]

    # data_dir could be a remote path, e.g., a Google Storage URL
    data_dir = params["data_dir"]

    def parser(record):
      # preprocess "inp_perm" and "tgt_perm"
      def _process_perm_feature(example, prefix):
        for b in range(len(bin_sizes)):
          cnt = example.pop("{}_cnt_{}".format(prefix, b))[0]
          tup = example.pop("{}_tup_{}".format(prefix, b))

          tup = tf.reshape(
              tf.sparse_tensor_to_dense(tup),
              shape=[cnt, 2])

          # tf.float32
          perm = tf.sparse_to_dense(
              sparse_indices=tup,
              output_shape=[tgt_len, bin_sizes[b]],
              sparse_values=1.0,
              default_value=0.0)

          example["{}_perm_{}".format(prefix, b)] = perm

      # whether to allow the last batch with a potentially shorter length
      if use_tpu:
        record_spec = {
            "inputs": tf.FixedLenFeature([tgt_len], tf.int64),
            "labels": tf.FixedLenFeature([tgt_len], tf.int64),
        }
      else:
        record_spec = {
            "inputs": tf.VarLenFeature(tf.int64),
            "labels": tf.VarLenFeature(tf.int64),
        }

      # permutation related features
      if bin_sizes and use_tpu:
        # tf.float32
        record_spec["inp_mask"] = tf.FixedLenFeature([tgt_len], tf.float32)
        record_spec["tgt_mask"] = tf.FixedLenFeature([tgt_len], tf.float32)

        record_spec["head_labels"] = tf.FixedLenFeature([tgt_len], tf.int64)

        for b in range(len(bin_sizes)):
          record_spec["inp_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64)
          record_spec["inp_tup_{}".format(b)] = tf.VarLenFeature(tf.int64)
          record_spec["tgt_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64)
          record_spec["tgt_tup_{}".format(b)] = tf.VarLenFeature(tf.int64)

      # retrieve serialized example
      example = tf.parse_single_example(
          serialized=record,
          features=record_spec)

      # transform permutation tuples to permutation matrices
      if bin_sizes and use_tpu:
        _process_perm_feature(example, "inp")
        _process_perm_feature(example, "tgt")

      # cast int64 into int32
      # cast sparse to dense
      for key in list(example.keys()):
        val = example[key]
        if tf.keras.backend.is_sparse(val):
          val = tf.sparse.to_dense(val)
        if val.dtype == tf.int64:
          val = tf.to_int32(val)
        example[key] = val

      if use_tpu:
        return example
      else:
        return example["inputs"], example["labels"]

    file_paths = []
    for file_name in file_names:
      file_path = os.path.join(data_dir, file_name)
      file_paths.append(file_path)

    if split == "train":
      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
      if len(file_paths) > 1:
        dataset = dataset.shuffle(len(file_paths)).repeat()
        dataset = tf.data.TFRecordDataset(dataset)
      elif num_hosts > 1:
        host_id = params["context"].current_host
        # drop the remaining batches
        num_batch_per_host = num_batch // num_hosts

        my_start_sample_id = (host_id * num_batch_per_host * num_core_per_host *
                              per_core_bsz)
        my_sample_num = num_batch_per_host * num_core_per_host * per_core_bsz
        dataset = tf.data.TFRecordDataset(dataset).skip(
            my_start_sample_id).take(my_sample_num)
      else:
        dataset = tf.data.TFRecordDataset(dataset)

      dataset = dataset.map(parser).cache().repeat()
      dataset = dataset.batch(per_core_bsz, drop_remainder=True)
      dataset = dataset.prefetch(num_core_per_host * per_core_bsz)
    else:
      # do not shuffle, repeat, or cache in evaluation
      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
      dataset = tf.data.TFRecordDataset(dataset)
      dataset = dataset.map(parser)
      dataset = dataset.batch(per_core_bsz, drop_remainder=True)

    return dataset

  if split == "train" and num_hosts > 1:
    record_info["num_batch"] = num_batch // num_hosts

  return input_fn, record_info


def get_corpus_info(corpus_info_path):
  with open(corpus_info_path, "r") as fp:
    corpus_info = json.load(fp)
  return corpus_info

if __name__ == "__main__":
  FLAGS = flags.FLAGS
  flags.DEFINE_string("data_dir", None,
                      help="Location of the data corpus.")
  flags.DEFINE_enum("dataset", "wt103",
                    ["ptb", "wt2", "wt103", "lm1b", "enwik8", "text8"],
                    help="Dataset name.")
  flags.DEFINE_integer("per_host_train_bsz", 60,
                       help="Train batch size per host.")
  flags.DEFINE_integer("per_host_valid_bsz", 60,
                       help="Valid batch size per host.")
  flags.DEFINE_integer("per_host_test_bsz", 0,
                       help="If > 0, enter test mode and process the test set "
                       "only. Otherwise, process the train and dev sets only.")
  flags.DEFINE_integer("tgt_len", 70,
                       help="Number of tokens to predict.")
  flags.DEFINE_integer("max_batch", -1,
                       help="Run in debug mode.")
  flags.DEFINE_integer("num_core_per_host", 8,
                       help="Number of cores per host; 8 for TPU v2.")
  flags.DEFINE_bool("debug", default=False,
                    help="Process only the first batch without shuffle for "
                    "lm1b.")
  flags.DEFINE_integer("num_procs", 1,
                       help="Number of worker processes.")
  flags.DEFINE_integer("num_passes", 10,
                       help="Number of passes when use_tpu=True.")
  flags.DEFINE_integer("num_shuffle", 4,
                       help="Number of shuffles for lm1b.")
  flags.DEFINE_bool("use_tpu", True,
                    help="Whether to use TPU.")

  tf.app.run(main)