## Introduction
This directory contains our pytorch implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our pytorch codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts:
- `*large.sh` are for the SoTA setting with large models, which might not be directly runnable on a local GPU machine.
- `*base.sh` are for the base models, which can be run on a few GPUs.
The pytorch implementation produces similar results to the TF codebase under the same settings in our preliminary experiments.
## Prerequisite

- PyTorch 0.4: `conda install pytorch torchvision -c pytorch`
## Data Preparation

`bash getdata.sh`
## Training and Evaluation
- Replicate the "bpc = 1.06" result on `enwik8` with a 12-layer Transformer-XL:
  - Make sure the machine has 4 GPUs, each with at least 11GB of memory.
  - Training: `bash run_enwik8_base.sh train --work_dir PATH_TO_WORK_DIR`
  - Evaluation: `bash run_enwik8_base.sh eval --work_dir PATH_TO_WORK_DIR`
- Replicate the "PPL = 24.03" result on `wikitext-103` with Transformer-XL:
  - Make sure the machine has 4 GPUs, each with at least 11GB of memory.
  - Training: `bash run_wt103_base.sh train --work_dir PATH_TO_WORK_DIR`
  - Evaluation: `bash run_wt103_base.sh eval --work_dir PATH_TO_WORK_DIR`
Other options:

- `--batch_chunk`: this option allows one to trade speed for memory. For `batch_chunk > 1`, the program will split each training batch into `batch_chunk` sub-batches and perform the forward and backward passes on each sub-batch sequentially, with the gradients accumulated and divided by `batch_chunk`. Hence, the memory usage will be proportionally lower, while the computation time will be correspondingly higher.
- `--div_val`: when using adaptive softmax and embedding, the embedding dimension is divided by `div_val` from bin `i` to bin `i+1`. This saves both GPU memory and the parameter budget.
- `--fp16` and `--dynamic-loss-scale`: run in pseudo-fp16 mode (fp16 storage, fp32 math) with dynamic loss scaling.
  - Note: to use the `--fp16` option, please make sure the `apex` package is installed (https://github.com/NVIDIA/apex/).
- To see performance without the recurrence mechanism, simply use `mem_len=0` in all your scripts.
- To see performance of a standard Transformer without relative positional encodings or recurrence mechanisms, use `attn_type=2` and `mem_len=0`.
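As a minimal sketch of the gradient-accumulation idea behind `--batch_chunk` (illustrative only: the linear model, data, and shapes below are stand-ins, not this repo's training loop), splitting a batch into sub-batches and dividing each sub-batch loss by the number of chunks yields the same gradient as a single full-batch pass:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)          # toy stand-in for the LM
criterion = torch.nn.MSELoss()

batch_chunk = 4
data = torch.randn(16, 8)              # one full training batch (hypothetical)
target = torch.randn(16, 1)

# Forward/backward on each sub-batch sequentially; dividing each loss by
# batch_chunk makes the accumulated gradient match the full-batch gradient.
model.zero_grad()
for data_i, target_i in zip(data.chunk(batch_chunk), target.chunk(batch_chunk)):
    loss = criterion(model(data_i), target_i) / batch_chunk
    loss.backward()                    # gradients accumulate across sub-batches
grad_accumulated = model.weight.grad.clone()

# Reference: one backward pass over the full batch.
model.zero_grad()
criterion(model(data), target).backward()
grad_full = model.weight.grad.clone()
```

Only one sub-batch's activations are alive at a time, which is where the memory saving comes from.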
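To make the effect of `--div_val` concrete, here is a small illustration of how the embedding dimension shrinks across frequency bins (the cutoffs and dimensions below are made-up example values, not the repo's defaults):

```python
# Illustrative only: how div_val scales down embedding sizes per frequency bin.
d_embed = 512                         # embedding dim of the most frequent bin
div_val = 2                           # division factor between consecutive bins
cutoffs = [0, 20000, 40000, 200000]   # hypothetical vocabulary cutoffs

# Bin i gets dimension d_embed // div_val**i, so rarer tokens use
# smaller (cheaper) embeddings.
dims = [d_embed // div_val**i for i in range(len(cutoffs) - 1)]
for i, d in enumerate(dims):
    print(f"bin {i}: tokens [{cutoffs[i]}, {cutoffs[i+1]}) -> embedding dim {d}")
# dims == [512, 256, 128]
```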
Other datasets:

- `text8` character-level language modeling: check out `run_text8_base.sh`
- `lm1b` word-level language modeling: check out `run_lm1b_base.sh`