# coding: utf-8
import argparse
import time
import math
import os, sys
import itertools
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
choices=['wt103', 'lm1b', 'enwik8', 'text8'],
help='dataset name')
parser.add_argument('--n_layer', type=int, default=12,
help='number of total layers')
parser.add_argument('--n_head', type=int, default=10,
help='number of heads')
parser.add_argument('--d_head', type=int, default=50,
help='head dimension')
parser.add_argument('--d_embed', type=int, default=-1,
help='embedding dimension')
parser.add_argument('--d_model', type=int, default=500,
help='model dimension')
parser.add_argument('--d_inner', type=int, default=1000,
help='inner dimension in FF')
parser.add_argument('--dropout', type=float, default=0.0,
help='global dropout rate')
parser.add_argument('--dropatt', type=float, default=0.0,
help='attention probability dropout rate')
parser.add_argument('--init', default='normal', type=str,
help='parameter initializer to use.')
parser.add_argument('--emb_init', default='normal', type=str,
help='embedding initializer to use.')
parser.add_argument('--init_range', type=float, default=0.1,
help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--emb_init_range', type=float, default=0.01,
help='embedding parameters initialized by U(-emb_init_range, emb_init_range)')
parser.add_argument('--init_std', type=float, default=0.02,
help='parameters initialized by N(0, init_std)')
parser.add_argument('--proj_init_std', type=float, default=0.01,
help='projection parameters initialized by N(0, proj_init_std)')
parser.add_argument('--optim', default='adam', type=str,
choices=['adam', 'sgd', 'adagrad'],
help='optimizer to use.')
parser.add_argument('--lr', type=float, default=0.00025,
help='initial learning rate (0.00025|5 for adam|sgd)')
parser.add_argument('--mom', type=float, default=0.0,
help='momentum for sgd')
parser.add_argument('--scheduler', default='cosine', type=str,
choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
help='lr scheduler to use.')
parser.add_argument('--warmup_step', type=int, default=0,
help='number of linear warmup steps')
parser.add_argument('--decay_rate', type=float, default=0.5,
help='decay factor when ReduceLROnPlateau is used')
parser.add_argument('--lr_min', type=float, default=0.0,
help='minimum learning rate during annealing')
parser.add_argument('--clip', type=float, default=0.25,
help='gradient clipping')
parser.add_argument('--clip_nonemb', action='store_true',
help='only clip the gradient of non-embedding params')
parser.add_argument('--max_step', type=int, default=100000,
help='maximum number of training steps')
parser.add_argument('--batch_size', type=int, default=60,
help='batch size')
parser.add_argument('--batch_chunk', type=int, default=1,
help='split batch into chunks to save memory')
parser.add_argument('--tgt_len', type=int, default=70,
help='number of tokens to predict')
parser.add_argument('--eval_tgt_len', type=int, default=50,
help='number of tokens to predict for evaluation')
parser.add_argument('--ext_len', type=int, default=0,
help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
help='length of the retained previous hidden states (memory)')
parser.add_argument('--not_tied', action='store_true',
help='do not tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
help='random seed')
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--adaptive', action='store_true',
help='use adaptive softmax')
parser.add_argument('--div_val', type=int, default=1,
help='divisor value for adaptive input and softmax')
parser.add_argument('--pre_lnorm', action='store_true',
help='apply LayerNorm to the input instead of the output')
parser.add_argument('--varlen', action='store_true',
help='use variable length')
parser.add_argument('--multi_gpu', action='store_true',
help='use multiple GPUs')
parser.add_argument('--log-interval', type=int, default=200,
help='report interval')
parser.add_argument('--eval-interval', type=int, default=4000,
help='evaluation interval')
parser.add_argument('--work_dir', default='LM-TFM', type=str,
help='experiment directory.')
parser.add_argument('--restart', action='store_true',
help='restart training from the saved checkpoint')
parser.add_argument('--restart_dir', type=str, default='',
help='restart dir')
parser.add_argument('--debug', action='store_true',
help='run in debug mode (do not create exp dir)')
parser.add_argument('--same_length', action='store_true',
help='use the same attn length for all tokens')
parser.add_argument('--attn_type', type=int, default=0,
help='attention type. 0 for ours, 1 for Shaw et al., '
'2 for Vaswani et al., 3 for Al Rfou et al.')
parser.add_argument('--clamp_len', type=int, default=-1,
help='use the same pos embeddings after clamp_len')
parser.add_argument('--eta_min', type=float, default=0.0,
help='min learning rate for cosine scheduler')
parser.add_argument('--gpu0_bsz', type=int, default=-1,
help='batch size on gpu 0')
parser.add_argument('--max_eval_steps', type=int, default=-1,
help='max eval steps')
parser.add_argument('--sample_softmax', type=int, default=-1,
help='number of samples in sampled softmax')
parser.add_argument('--patience', type=int, default=0,
help='patience for the dev_perf (ReduceLROnPlateau) scheduler')
parser.add_argument('--finetune_v2', action='store_true',
help='finetune v2')
parser.add_argument('--finetune_v3', action='store_true',
help='finetune v3')
parser.add_argument('--fp16', action='store_true',
help='Run in pseudo-fp16 mode (fp16 storage, fp32 math).')
parser.add_argument('--static-loss-scale', type=float, default=1,
help='Static loss scale; positive power-of-2 values can '
'improve fp16 convergence.')
parser.add_argument('--dynamic-loss-scale', action='store_true',
help='Use dynamic loss scaling. If supplied, this argument'
' supersedes --static-loss-scale.')
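# Example invocation (illustrative only: every flag below is defined above, but
# the paths and hyper-parameter values are placeholders to adapt to your setup):
#   python train.py --cuda --data ../data/wikitext-103 --dataset wt103 \
#       --adaptive --n_layer 12 --d_model 500 --n_head 10 --d_head 50 \
#       --d_inner 1000 --dropout 0.1 --optim adam --lr 0.00025 \
#       --max_step 100000 --tgt_len 70 --mem_len 70 --eval_tgt_len 50 \
#       --batch_size 60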
args = parser.parse_args()
args.tied = not args.not_tied
if args.d_embed < 0:
args.d_embed = args.d_model
assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0, 'batch_size must be divisible by batch_chunk'
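# --batch_chunk > 1 splits every batch into that many chunks that are run
# sequentially with gradient accumulation, trading throughput for lower peak
# GPU memory; that is why the assert above requires divisibility.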
args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug)
# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
if not args.cuda:
print('WARNING: You have a CUDA device, so you should probably run with --cuda')
else:
torch.cuda.manual_seed_all(args.seed)
# Validate `--fp16` option
if args.fp16:
if not args.cuda:
print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
args.fp16 = False
else:
try:
from apex.fp16_utils import FP16_Optimizer
except ImportError:
print('WARNING: apex not installed, ignoring --fp16 option')
args.fp16 = False
device = torch.device('cuda' if args.cuda else 'cpu')
###############################################################################
# Load data
###############################################################################
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
args.n_token = ntokens
eval_batch_size = 10
tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,
device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
device=device, ext_len=args.ext_len)
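# Each iterator yields (data, target, seq_len) with data and target laid out as
# (sequence_length, batch_size); the batch dimension is dim=1, which matters
# for the DataParallel setup further below.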
# adaptive softmax / embedding
cutoffs, tie_projs = [], [False]
if args.adaptive:
assert args.dataset in ['wt103', 'lm1b']
if args.dataset == 'wt103':
cutoffs = [20000, 40000, 200000]
tie_projs += [True] * len(cutoffs)
elif args.dataset == 'lm1b':
cutoffs = [60000, 100000, 640000]
tie_projs += [False] * len(cutoffs)
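# `cutoffs` give the vocabulary cluster boundaries for the adaptive embedding
# and softmax; `tie_projs` records, per cluster, whether the input and output
# projections are tied (the leading False corresponds to the head cluster).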
###############################################################################
# Build the model
###############################################################################
def init_weight(weight):
if args.init == 'uniform':
nn.init.uniform_(weight, -args.init_range, args.init_range)
elif args.init == 'normal':
nn.init.normal_(weight, 0.0, args.init_std)
def init_bias(bias):
nn.init.constant_(bias, 0.0)
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Linear') != -1:
if hasattr(m, 'weight') and m.weight is not None:
init_weight(m.weight)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('AdaptiveEmbedding') != -1:
if hasattr(m, 'emb_projs'):
for i in range(len(m.emb_projs)):
if m.emb_projs[i] is not None:
nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
elif classname.find('Embedding') != -1:
if hasattr(m, 'weight'):
init_weight(m.weight)
elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
init_weight(m.cluster_weight)
if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
init_bias(m.cluster_bias)
if hasattr(m, 'out_projs'):
for i in range(len(m.out_projs)):
if m.out_projs[i] is not None:
nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
elif classname.find('LayerNorm') != -1:
if hasattr(m, 'weight'):
nn.init.normal_(m.weight, 1.0, args.init_std)
if hasattr(m, 'bias') and m.bias is not None:
init_bias(m.bias)
elif classname.find('TransformerLM') != -1:
if hasattr(m, 'r_emb'):
init_weight(m.r_emb)
if hasattr(m, 'r_w_bias'):
init_weight(m.r_w_bias)
if hasattr(m, 'r_r_bias'):
init_weight(m.r_r_bias)
if hasattr(m, 'r_bias'):
init_bias(m.r_bias)
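# weights_init is applied recursively via model.apply(); it dispatches on the
# class name so both standard nn modules (Linear, Embedding, LayerNorm) and the
# custom Transformer-XL modules (AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax,
# the r_* relative-position parameters) are initialized consistently.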
def update_dropout(m):
classname = m.__class__.__name__
if classname.find('Dropout') != -1:
if hasattr(m, 'p'):
m.p = args.dropout
def update_dropatt(m):
if hasattr(m, 'dropatt'):
m.dropatt.p = args.dropatt
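# When restarting from a checkpoint, the two helpers above let the command-line
# --dropout/--dropatt values override whatever probabilities were saved with
# the model.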
if args.restart:
with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
model = torch.load(f)
if not args.fp16:
model = model.float()
model.apply(update_dropout)
model.apply(update_dropatt)
else:
model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
args.d_head, args.d_inner, args.dropout, args.dropatt,
tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
same_length=args.same_length, attn_type=args.attn_type,
clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
model.apply(weights_init)
model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])
if args.fp16:
model = model.half()
if args.multi_gpu:
model = model.to(device)
if args.gpu0_bsz >= 0:
para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
model, dim=1).to(device)
else:
para_model = nn.DataParallel(model, dim=1).to(device)
else:
para_model = model.to(device)
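# Batches are scattered along dim=1 (the batch dimension) across GPUs.
# BalancedDataParallel gives GPU 0 a smaller share of each batch (gpu0_bsz,
# divided by batch_chunk), typically because GPU 0 also gathers the outputs.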
#### optimizer
if args.optim.lower() == 'sgd':
if args.sample_softmax > 0:
dense_params, sparse_params = [], []
for param in model.parameters():
if param.size() == model.word_emb.weight.size():
sparse_params.append(param)
else:
dense_params.append(param)
optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
else:
optimizer = optim.SGD(model.parameters(), lr=args.lr,
momentum=args.mom)
elif args.optim.lower() == 'adam':
if args.sample_softmax > 0:
dense_params, sparse_params = [], []
for param in model.parameters():
if param.size() == model.word_emb.weight.size():
sparse_params.append(param)
else:
dense_params.append(param)
optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
optimizer = optim.Adam(dense_params, lr=args.lr)
else:
optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optim.lower() == 'adagrad':
optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
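# When --sample_softmax > 0, parameters with the same shape as the word
# embedding are treated as sparse and given their own optimizer (SGD or
# SparseAdam); all remaining parameters use the dense optimizer above.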
#### scheduler
if args.scheduler == 'cosine':
# For backward compatibility we do not set eta_min to lr_min here:
# in previous versions eta_min defaulted to 0 rather than to the
# lr_min default of 1e-6.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
args.max_step, eta_min=args.eta_min) # should use eta_min arg
if args.sample_softmax > 0:
scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
args.max_step, eta_min=args.eta_min) # should use eta_min arg
elif args.scheduler == 'inv_sqrt':
# originally used for Transformer (in Attention is all you need)
def lr_lambda(step):
# return a multiplier instead of a learning rate
if step == 0 and args.warmup_step == 0:
return 1.
else:
return 1. / (step ** 0.5) if step > args.warmup_step \
else step / (args.warmup_step ** 1.5)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif args.scheduler == 'dev_perf':
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
if args.sample_softmax > 0:
scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
elif args.scheduler == 'constant':
pass
if args.cuda and args.fp16:
# If args.dynamic_loss_scale is False, static_loss_scale will be used.
# If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
optimizer = FP16_Optimizer(optimizer,
static_loss_scale = args.static_loss_scale,
dynamic_loss_scale = args.dynamic_loss_scale,
dynamic_loss_args = {'init_scale': 2 ** 16})
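# FP16_Optimizer keeps fp32 master copies of the weights and applies loss
# scaling; the training loop below calls its optimizer.backward(loss) and
# clip_master_grads() instead of the plain fp32 code path.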
if args.restart:
if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
opt_state_dict = torch.load(f)
optimizer.load_state_dict(opt_state_dict)
else:
print('Optimizer state was not saved; starting from scratch.')
logging('=' * 100)
for k, v in args.__dict__.items():
logging(' - {} : {}'.format(k, v))
logging('=' * 100)
logging('#params = {}'.format(args.n_all_param))
logging('#non emb params = {}'.format(args.n_nonemb_param))
###############################################################################
# Training code
###############################################################################
def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout.
model.eval()
# If the model does not use memory at all, make the ext_len longer.
# Otherwise, make the mem_len longer and keep the ext_len the same.
if args.mem_len == 0:
model.reset_length(args.eval_tgt_len,
args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len)
else:
model.reset_length(args.eval_tgt_len,
args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len)
# Evaluation
total_len, total_loss = 0, 0.
with torch.no_grad():
mems = tuple()
for i, (data, target, seq_len) in enumerate(eval_iter):
if args.max_eval_steps > 0 and i >= args.max_eval_steps:
break
ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:]
loss = loss.mean()
total_loss += seq_len * loss.float().item()
total_len += seq_len
# Switch back to the training mode
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
model.train()
return total_loss / total_len
def train():
# Turn on training mode which enables dropout.
global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
model.train()
if args.batch_chunk > 1:
mems = [tuple() for _ in range(args.batch_chunk)]
else:
mems = tuple()
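# With batch chunking, each chunk is a fixed slice of the batch dimension, so
# each chunk carries its own recurrent memory across training iterations.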
train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
for batch, (data, target, seq_len) in enumerate(train_iter):
model.zero_grad()
if args.batch_chunk > 1:
data_chunks = torch.chunk(data, args.batch_chunk, 1)
target_chunks = torch.chunk(target, args.batch_chunk, 1)
for i in range(args.batch_chunk):
data_i = data_chunks[i].contiguous()
target_i = target_chunks[i].contiguous()
ret = para_model(data_i, target_i, *mems[i])
loss, mems[i] = ret[0], ret[1:]
loss = loss.float().mean().type_as(loss) / args.batch_chunk
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
train_loss += loss.float().item()
else:
ret = para_model(data, target, *mems)
loss, mems = ret[0], ret[1:]
loss = loss.float().mean().type_as(loss)
if args.fp16:
optimizer.backward(loss)
else:
loss.backward()
train_loss += loss.float().item()
if args.fp16:
optimizer.clip_master_grads(args.clip)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
if args.sample_softmax > 0:
optimizer_sparse.step()
# step-wise learning rate annealing
train_step += 1
if args.scheduler in ['cosine', 'constant', 'dev_perf']:
# linear warmup stage
if train_step < args.warmup_step:
curr_lr = args.lr * train_step / args.warmup_step
optimizer.param_groups[0]['lr'] = curr_lr
if args.sample_softmax > 0:
optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
else:
if args.scheduler == 'cosine':
scheduler.step(train_step)
if args.sample_softmax > 0:
scheduler_sparse.step(train_step)
elif args.scheduler == 'inv_sqrt':
scheduler.step(train_step)
if train_step % args.log_interval == 0:
cur_loss = train_loss / args.log_interval
elapsed = time.time() - log_start_time
log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
'| ms/batch {:5.2f} | loss {:5.2f}'.format(
epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
elapsed * 1000 / args.log_interval, cur_loss)
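# For character-level datasets the running loss is reported as bits per
# character (loss / ln 2); otherwise as perplexity (exp(loss)).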
if args.dataset in ['enwik8', 'text8']:
log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
else:
log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
logging(log_str)
train_loss = 0
log_start_time = time.time()
if train_step % args.eval_interval == 0:
val_loss = evaluate(va_iter)
logging('-' * 100)
log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
'| valid loss {:5.2f}'.format(
train_step // args.eval_interval, train_step,
(time.time() - eval_start_time), val_loss)
if args.dataset in ['enwik8', 'text8']:
log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
else:
log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
logging(log_str)
logging('-' * 100)
# Save the model if the validation loss is the best we've seen so far.
if not best_val_loss or val_loss < best_val_loss:
if not args.debug:
with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
torch.save(model, f)
with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
torch.save(optimizer.state_dict(), f)
best_val_loss = val_loss
# dev-performance based learning rate annealing
if args.scheduler == 'dev_perf':
scheduler.step(val_loss)
if args.sample_softmax > 0:
scheduler_sparse.step(val_loss)
eval_start_time = time.time()
if train_step == args.max_step:
break
# Loop over epochs.
train_step = 0
train_loss = 0
best_val_loss = None
log_start_time = time.time()
eval_start_time = time.time()
# At any point you can hit Ctrl + C to break out of training early.
try:
for epoch in itertools.count(start=1):
train()
if train_step == args.max_step:
logging('-' * 100)
logging('End of training')
break
except KeyboardInterrupt:
logging('-' * 100)
logging('Exiting from training early')
# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
model = torch.load(f)
para_model = model.to(device)
# Run on test data.
test_loss = evaluate(te_iter)
logging('=' * 100)
if args.dataset in ['enwik8', 'text8']:
logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
test_loss, test_loss / math.log(2)))
else:
logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
test_loss, math.exp(test_loss)))
logging('=' * 100)
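# Note: the best checkpoint and optimizer state are written to
# <work_dir>/model.pt and <work_dir>/optimizer.pt, where work_dir is
# '<--work_dir>-<dataset>/<timestamp>' (e.g. LM-TFM-wt103/20250101-000000,
# an illustrative timestamp), and can be reloaded with torch.load() as done
# above for the final test run.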