From f32f4678e149968b893854701aef2a276cc5e621 Mon Sep 17 00:00:00 2001 From: Tibo De Peuter Date: Fri, 5 Dec 2025 12:37:48 +0100 Subject: [PATCH] chore: Restructure --- simple/.keep | 0 .../dataset_loaders}/Dataset.py | 0 .../dataset_loaders}/EnWik9.py | 0 .../dataset_loaders}/LoremIpsumDataset.py | 0 .../dataset_loaders}/OpenGenomeDataset.py | 0 .../dataset_loaders}/__init__.py | 0 {models => src/models}/__init__.py | 0 {models => src/models}/cnn/__init__.py | 0 {models => src/models}/cnn/cnn.py | 0 {trainers => src/trainers}/FullTrainer.py | 0 {trainers => src/trainers}/OptunaTrainer.py | 0 {trainers => src/trainers}/__init__.py | 0 {trainers => src/trainers}/train.py | 0 {trainers => src/trainers}/trainer.py | 0 {utils => src/utils}/__init__.py | 0 {utils => src/utils}/utils.py | 0 transformer-xl/LICENSE | 201 - transformer-xl/README.md | 34 - transformer-xl/getdata.sh | 90 - transformer-xl/prep_text8.py | 32 - transformer-xl/pytorch/.DS_Store | Bin 6148 -> 0 bytes transformer-xl/pytorch/README.md | 62 - transformer-xl/pytorch/data_utils.py | 273 -- transformer-xl/pytorch/eval.py | 122 - transformer-xl/pytorch/mem_transformer.py | 812 ---- transformer-xl/pytorch/run_enwik8_base.sh | 41 - transformer-xl/pytorch/run_enwik8_large.sh | 41 - transformer-xl/pytorch/run_lm1b_base.sh | 43 - transformer-xl/pytorch/run_lm1b_large.sh | 43 - transformer-xl/pytorch/run_text8_base.sh | 41 - transformer-xl/pytorch/run_text8_large.sh | 38 - transformer-xl/pytorch/run_wt103_base.sh | 42 - transformer-xl/pytorch/run_wt103_large.sh | 43 - transformer-xl/pytorch/train.py | 562 --- .../pytorch/utils/adaptive_softmax.py | 90 - transformer-xl/pytorch/utils/data_parallel.py | 91 - transformer-xl/pytorch/utils/exp_utils.py | 40 - .../pytorch/utils/log_uniform_sampler.py | 147 - .../pytorch/utils/proj_adaptive_softmax.py | 151 - transformer-xl/pytorch/utils/vocabulary.py | 163 - transformer-xl/tf/README.md | 131 - transformer-xl/tf/avg_checkpoints.py | 118 - transformer-xl/tf/data_utils.py | 586 --- transformer-xl/tf/gpu_utils.py | 65 - transformer-xl/tf/model.py | 546 --- transformer-xl/tf/scripts/enwik8_base_gpu.sh | 102 - transformer-xl/tf/scripts/enwik8_large_tpu.sh | 122 - transformer-xl/tf/scripts/lm1b_base_gpu.sh | 110 - transformer-xl/tf/scripts/lm1b_large_tpu.sh | 136 - transformer-xl/tf/scripts/text8_base_gpu.sh | 102 - transformer-xl/tf/scripts/text8_large_tpu.sh | 122 - transformer-xl/tf/scripts/wt103_base_gpu.sh | 108 - transformer-xl/tf/scripts/wt103_large_tpu.sh | 134 - transformer-xl/tf/sota/download.sh | 87 - transformer-xl/tf/sota/enwik8.sh | 58 - transformer-xl/tf/sota/lm1b.sh | 63 - transformer-xl/tf/sota/text8.sh | 58 - transformer-xl/tf/sota/wt103.sh | 71 - transformer-xl/tf/tpu_estimator.py | 3519 ----------------- transformer-xl/tf/train.py | 462 --- transformer-xl/tf/train_gpu.py | 475 --- transformer-xl/tf/vocabulary.py | 170 - 62 files changed, 10547 deletions(-) delete mode 100644 simple/.keep rename {dataset_loaders => src/dataset_loaders}/Dataset.py (100%) rename {dataset_loaders => src/dataset_loaders}/EnWik9.py (100%) rename {dataset_loaders => src/dataset_loaders}/LoremIpsumDataset.py (100%) rename {dataset_loaders => src/dataset_loaders}/OpenGenomeDataset.py (100%) rename {dataset_loaders => src/dataset_loaders}/__init__.py (100%) rename {models => src/models}/__init__.py (100%) rename {models => src/models}/cnn/__init__.py (100%) rename {models => src/models}/cnn/cnn.py (100%) rename {trainers => src/trainers}/FullTrainer.py (100%) rename {trainers => 
src/trainers}/OptunaTrainer.py (100%) rename {trainers => src/trainers}/__init__.py (100%) rename {trainers => src/trainers}/train.py (100%) rename {trainers => src/trainers}/trainer.py (100%) rename {utils => src/utils}/__init__.py (100%) rename {utils => src/utils}/utils.py (100%) delete mode 100644 transformer-xl/LICENSE delete mode 100644 transformer-xl/README.md delete mode 100755 transformer-xl/getdata.sh delete mode 100644 transformer-xl/prep_text8.py delete mode 100644 transformer-xl/pytorch/.DS_Store delete mode 100644 transformer-xl/pytorch/README.md delete mode 100644 transformer-xl/pytorch/data_utils.py delete mode 100644 transformer-xl/pytorch/eval.py delete mode 100644 transformer-xl/pytorch/mem_transformer.py delete mode 100644 transformer-xl/pytorch/run_enwik8_base.sh delete mode 100644 transformer-xl/pytorch/run_enwik8_large.sh delete mode 100644 transformer-xl/pytorch/run_lm1b_base.sh delete mode 100644 transformer-xl/pytorch/run_lm1b_large.sh delete mode 100644 transformer-xl/pytorch/run_text8_base.sh delete mode 100644 transformer-xl/pytorch/run_text8_large.sh delete mode 100644 transformer-xl/pytorch/run_wt103_base.sh delete mode 100644 transformer-xl/pytorch/run_wt103_large.sh delete mode 100644 transformer-xl/pytorch/train.py delete mode 100644 transformer-xl/pytorch/utils/adaptive_softmax.py delete mode 100644 transformer-xl/pytorch/utils/data_parallel.py delete mode 100644 transformer-xl/pytorch/utils/exp_utils.py delete mode 100644 transformer-xl/pytorch/utils/log_uniform_sampler.py delete mode 100644 transformer-xl/pytorch/utils/proj_adaptive_softmax.py delete mode 100644 transformer-xl/pytorch/utils/vocabulary.py delete mode 100644 transformer-xl/tf/README.md delete mode 100644 transformer-xl/tf/avg_checkpoints.py delete mode 100644 transformer-xl/tf/data_utils.py delete mode 100644 transformer-xl/tf/gpu_utils.py delete mode 100644 transformer-xl/tf/model.py delete mode 100644 transformer-xl/tf/scripts/enwik8_base_gpu.sh delete mode 100644 transformer-xl/tf/scripts/enwik8_large_tpu.sh delete mode 100644 transformer-xl/tf/scripts/lm1b_base_gpu.sh delete mode 100644 transformer-xl/tf/scripts/lm1b_large_tpu.sh delete mode 100644 transformer-xl/tf/scripts/text8_base_gpu.sh delete mode 100644 transformer-xl/tf/scripts/text8_large_tpu.sh delete mode 100644 transformer-xl/tf/scripts/wt103_base_gpu.sh delete mode 100644 transformer-xl/tf/scripts/wt103_large_tpu.sh delete mode 100644 transformer-xl/tf/sota/download.sh delete mode 100644 transformer-xl/tf/sota/enwik8.sh delete mode 100644 transformer-xl/tf/sota/lm1b.sh delete mode 100644 transformer-xl/tf/sota/text8.sh delete mode 100644 transformer-xl/tf/sota/wt103.sh delete mode 100644 transformer-xl/tf/tpu_estimator.py delete mode 100644 transformer-xl/tf/train.py delete mode 100644 transformer-xl/tf/train_gpu.py delete mode 100644 transformer-xl/tf/vocabulary.py diff --git a/simple/.keep b/simple/.keep deleted file mode 100644 index e69de29..0000000 diff --git a/dataset_loaders/Dataset.py b/src/dataset_loaders/Dataset.py similarity index 100% rename from dataset_loaders/Dataset.py rename to src/dataset_loaders/Dataset.py diff --git a/dataset_loaders/EnWik9.py b/src/dataset_loaders/EnWik9.py similarity index 100% rename from dataset_loaders/EnWik9.py rename to src/dataset_loaders/EnWik9.py diff --git a/dataset_loaders/LoremIpsumDataset.py b/src/dataset_loaders/LoremIpsumDataset.py similarity index 100% rename from dataset_loaders/LoremIpsumDataset.py rename to src/dataset_loaders/LoremIpsumDataset.py diff --git 
a/dataset_loaders/OpenGenomeDataset.py b/src/dataset_loaders/OpenGenomeDataset.py similarity index 100% rename from dataset_loaders/OpenGenomeDataset.py rename to src/dataset_loaders/OpenGenomeDataset.py diff --git a/dataset_loaders/__init__.py b/src/dataset_loaders/__init__.py similarity index 100% rename from dataset_loaders/__init__.py rename to src/dataset_loaders/__init__.py diff --git a/models/__init__.py b/src/models/__init__.py similarity index 100% rename from models/__init__.py rename to src/models/__init__.py diff --git a/models/cnn/__init__.py b/src/models/cnn/__init__.py similarity index 100% rename from models/cnn/__init__.py rename to src/models/cnn/__init__.py diff --git a/models/cnn/cnn.py b/src/models/cnn/cnn.py similarity index 100% rename from models/cnn/cnn.py rename to src/models/cnn/cnn.py diff --git a/trainers/FullTrainer.py b/src/trainers/FullTrainer.py similarity index 100% rename from trainers/FullTrainer.py rename to src/trainers/FullTrainer.py diff --git a/trainers/OptunaTrainer.py b/src/trainers/OptunaTrainer.py similarity index 100% rename from trainers/OptunaTrainer.py rename to src/trainers/OptunaTrainer.py diff --git a/trainers/__init__.py b/src/trainers/__init__.py similarity index 100% rename from trainers/__init__.py rename to src/trainers/__init__.py diff --git a/trainers/train.py b/src/trainers/train.py similarity index 100% rename from trainers/train.py rename to src/trainers/train.py diff --git a/trainers/trainer.py b/src/trainers/trainer.py similarity index 100% rename from trainers/trainer.py rename to src/trainers/trainer.py diff --git a/utils/__init__.py b/src/utils/__init__.py similarity index 100% rename from utils/__init__.py rename to src/utils/__init__.py diff --git a/utils/utils.py b/src/utils/utils.py similarity index 100% rename from utils/utils.py rename to src/utils/utils.py diff --git a/transformer-xl/LICENSE b/transformer-xl/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/transformer-xl/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. 
- - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. 
You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. 
In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/transformer-xl/README.md b/transformer-xl/README.md deleted file mode 100644 index 9f12978..0000000 --- a/transformer-xl/README.md +++ /dev/null @@ -1,34 +0,0 @@ -# Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context - -This repository contains the code in both **PyTorch** and **TensorFlow** for our paper ->[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](http://arxiv.org/abs/1901.02860) - ->Zihang Dai\*, Zhilin Yang\*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution) - ->Preprint 2018 - -## TensorFlow - -- The source code is in the `tf/` folder, supporting (1) single-node multi-gpu training, and (2) multi-host TPU training. -- Besides the source code, we also provide pretrained "TensorFlow" models with state-of-the-art (SoTA) performances reported in the paper. -- Please refer to `tf/README.md` for details. 
- -## PyTorch - -- The source code is in the `pytorch/` folder, supporting single-node multi-gpu training via the module `nn.DataParallel`. -- Please refer to `pytorch/README.md` for details. - -## Results - -Transformer-XL achieves new state-of-the-art results on multiple language modeling benchmarks. Transformer-XL is also the first to break through the 1.0 barrier on char-level language modeling. Below is a summary. - -Method | enwiki8 | text8 | One Billion Word | WT-103 | PTB (w/o finetuning) --- | -- | -- | -- | -- | -- -Previous Best | 1.06 | 1.13 | 23.7 | 20.5 | 55.5 -Transformer-XL | **0.99** | **1.08** | **21.8** | **18.3** | **54.5** - - - -## Acknowledgement - -A large portion of the `getdata.sh` script comes from the [awd-lstm](https://github.com/salesforce/awd-lstm-lm/) repo. Happy Language Modeling :) diff --git a/transformer-xl/getdata.sh b/transformer-xl/getdata.sh deleted file mode 100755 index 7804757..0000000 --- a/transformer-xl/getdata.sh +++ /dev/null @@ -1,90 +0,0 @@ -echo "=== Acquiring datasets ===" -echo "---" - -mkdir -p data -cd data - -if [[ ! -d 'wikitext-2' ]]; then - echo "- Downloading WikiText-2 (WT2)" - wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip - unzip -q wikitext-2-v1.zip - cd wikitext-2 - mv wiki.train.tokens train.txt - mv wiki.valid.tokens valid.txt - mv wiki.test.tokens test.txt - cd .. -fi - -echo "- Downloading WikiText-103 (WT2)" -if [[ ! -d 'wikitext-103' ]]; then - wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip - unzip -q wikitext-103-v1.zip - cd wikitext-103 - mv wiki.train.tokens train.txt - mv wiki.valid.tokens valid.txt - mv wiki.test.tokens test.txt - cd .. -fi - -echo "- Downloading enwik8 (Character)" -if [[ ! -d 'enwik8' ]]; then - mkdir -p enwik8 - cd enwik8 - wget --continue http://mattmahoney.net/dc/enwik8.zip - wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py - python3 prep_enwik8.py - cd .. -fi - -echo "- Downloading text8 (Character)" -if [[ ! -d 'text8' ]]; then - mkdir -p text8 - cd text8 - wget --continue http://mattmahoney.net/dc/text8.zip - python ../../prep_text8.py - cd .. -fi - -echo "- Downloading Penn Treebank (PTB)" -if [[ ! -d 'penn' ]]; then - wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz - tar -xzf simple-examples.tgz - - mkdir -p penn - cd penn - mv ../simple-examples/data/ptb.train.txt train.txt - mv ../simple-examples/data/ptb.test.txt test.txt - mv ../simple-examples/data/ptb.valid.txt valid.txt - cd .. - - echo "- Downloading Penn Treebank (Character)" - mkdir -p pennchar - cd pennchar - mv ../simple-examples/data/ptb.char.train.txt train.txt - mv ../simple-examples/data/ptb.char.test.txt test.txt - mv ../simple-examples/data/ptb.char.valid.txt valid.txt - cd .. - - rm -rf simple-examples/ -fi - -echo "- Downloading 1B words" - -if [[ ! -d 'one-billion-words' ]]; then - mkdir -p one-billion-words - cd one-billion-words - - wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz - tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz - - path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/" - cat ${path}/news.en.heldout-00000-of-00050 > valid.txt - cat ${path}/news.en.heldout-00000-of-00050 > test.txt - - wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt - - cd .. 
-fi - -echo "---" -echo "Happy language modeling :)" diff --git a/transformer-xl/prep_text8.py b/transformer-xl/prep_text8.py deleted file mode 100644 index 65b1ce7..0000000 --- a/transformer-xl/prep_text8.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 - -import os -import sys -import zipfile - -from io import open - -if os.path.exists('train.txt'): - print('Tokenized text8 already exists - skipping processing') - sys.exit() - -data = zipfile.ZipFile('text8.zip').extractall() -data = open('text8', 'r', encoding='utf-8').read() - -print('Length of text8: {}'.format(len(data))) - -num_test_chars = 5000000 - -train_data = data[: -2 * num_test_chars] -valid_data = data[-2 * num_test_chars: -num_test_chars] -test_data = data[-num_test_chars:] - -for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]: - print('{} will have {} bytes'.format(fn, len(part))) - print('- Tokenizing...') - # Change space ' ' to underscore '_' - part_str = ' '.join(['_' if c == ' ' else c for c in part.strip()]) - print('- Writing...') - f = open(fn, 'w').write(part_str) - f = open(fn + '.raw', 'w', encoding='utf-8').write(part) diff --git a/transformer-xl/pytorch/.DS_Store b/transformer-xl/pytorch/.DS_Store deleted file mode 100644 index 5008ddfcf53c02e82d7eee2e57c38e5672ef89f6..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeH~Jr2S!425mzP>H1@V-^m;4Wg<&0T*E43hX&L&p$$qDprKhvt+--jT7}7np#A3 zem<@ulZcFPQ@L2!n>{z**++&mCkOWA81W14cNZlEfg7;MkzE(HCqgga^y>{tEnwC%0;vJ&^%eQ zLs35+`xjp>T0 1`, the program will split each training batch into `batch_chunk` sub-batches and perform forward and backward on each sub-batch sequentially, with the gradient accumulated and divided by `batch_chunk`. Hence, the memory usage will propertionally lower while the computation time will inversely higher. -- `--div_val`: when using adaptive softmax and embedding, the embedding dimension is divided by `div_val` from bin $i$ to bin $i+1$. This saves both GPU memory and the parameter budget. -- `--fp16` and `--dynamic-loss-scale`: Run in pseudo-fp16 mode (fp16 storage fp32 math) with dynamic loss scaling. - - Note: to explore the `--fp16` option, please make sure the `apex` package is installed (https://github.com/NVIDIA/apex/). -- To see performance without the recurrence mechanism, simply use `mem_len=0` in all your scripts. -- To see performance of a standard Transformer without relative positional encodings or recurrence mechanisms, use `attn_type=2` and `mem_len=0`. - - -#### Other datasets: - -- `Text8` character-level language modeling: check out `run_text8_base.sh` -- `lm1b` word-level language modeling: check out `run_lm1b_base.sh` diff --git a/transformer-xl/pytorch/data_utils.py b/transformer-xl/pytorch/data_utils.py deleted file mode 100644 index df762a7..0000000 --- a/transformer-xl/pytorch/data_utils.py +++ /dev/null @@ -1,273 +0,0 @@ -import os, sys -import glob - -from collections import Counter, OrderedDict -import numpy as np -import torch - -from utils.vocabulary import Vocab - -class LMOrderedIterator(object): - def __init__(self, data, bsz, bptt, device='cpu', ext_len=None): - """ - data -- LongTensor -- the LongTensor is strictly ordered - """ - self.bsz = bsz - self.bptt = bptt - self.ext_len = ext_len if ext_len is not None else 0 - - self.device = device - - # Work out how cleanly we can divide the dataset into bsz parts. 
- self.n_step = data.size(0) // bsz - - # Trim off any extra elements that wouldn't cleanly fit (remainders). - data = data.narrow(0, 0, self.n_step * bsz) - - # Evenly divide the data across the bsz batches. - self.data = data.view(bsz, -1).t().contiguous().to(device) - - # Number of mini-batches - self.n_batch = (self.n_step + self.bptt - 1) // self.bptt - - def get_batch(self, i, bptt=None): - if bptt is None: bptt = self.bptt - seq_len = min(bptt, self.data.size(0) - 1 - i) - - end_idx = i + seq_len - beg_idx = max(0, i - self.ext_len) - - data = self.data[beg_idx:end_idx] - target = self.data[i+1:i+1+seq_len] - - return data, target, seq_len - - def get_fixlen_iter(self, start=0): - for i in range(start, self.data.size(0) - 1, self.bptt): - yield self.get_batch(i) - - def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): - max_len = self.bptt + max_deviation * std - i = start - while True: - bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2. - bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std)))) - data, target, seq_len = self.get_batch(i, bptt) - i += seq_len - yield data, target, seq_len - if i >= self.data.size(0) - 2: - break - - def __iter__(self): - return self.get_fixlen_iter() - - -class LMShuffledIterator(object): - def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False): - """ - data -- list[LongTensor] -- there is no order among the LongTensors - """ - self.data = data - - self.bsz = bsz - self.bptt = bptt - self.ext_len = ext_len if ext_len is not None else 0 - - self.device = device - self.shuffle = shuffle - - def get_sent_stream(self): - # index iterator - epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \ - else np.array(range(len(self.data))) - - # sentence iterator - for idx in epoch_indices: - yield self.data[idx] - - def stream_iterator(self, sent_stream): - # streams for each data in the batch - streams = [None] * self.bsz - - data = torch.LongTensor(self.bptt, self.bsz) - target = torch.LongTensor(self.bptt, self.bsz) - - n_retain = 0 - - while True: - # data : [n_retain+bptt x bsz] - # target : [bptt x bsz] - data[n_retain:].fill_(-1) - target.fill_(-1) - - valid_batch = True - - for i in range(self.bsz): - n_filled = 0 - try: - while n_filled < self.bptt: - if streams[i] is None or len(streams[i]) <= 1: - streams[i] = next(sent_stream) - # number of new tokens to fill in - n_new = min(len(streams[i]) - 1, self.bptt - n_filled) - # first n_retain tokens are retained from last batch - data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \ - streams[i][:n_new] - target[n_filled:n_filled+n_new, i] = \ - streams[i][1:n_new+1] - streams[i] = streams[i][n_new:] - n_filled += n_new - except StopIteration: - valid_batch = False - break - - if not valid_batch: - return - - data = data.to(self.device) - target = target.to(self.device) - - yield data, target, self.bptt - - n_retain = min(data.size(0), self.ext_len) - if n_retain > 0: - data[:n_retain] = data[-n_retain:] - data.resize_(n_retain + self.bptt, data.size(1)) - - def __iter__(self): - # sent_stream is an iterator - sent_stream = self.get_sent_stream() - - for batch in self.stream_iterator(sent_stream): - yield batch - - -class LMMultiFileIterator(LMShuffledIterator): - def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None, - shuffle=False): - - self.paths = paths - self.vocab = vocab - - self.bsz = bsz - self.bptt = bptt - self.ext_len = ext_len if ext_len is not None else 0 - - self.device = device - 
self.shuffle = shuffle - - def get_sent_stream(self, path): - sents = self.vocab.encode_file(path, add_double_eos=True) - if self.shuffle: - np.random.shuffle(sents) - sent_stream = iter(sents) - - return sent_stream - - def __iter__(self): - if self.shuffle: - np.random.shuffle(self.paths) - - for path in self.paths: - # sent_stream is an iterator - sent_stream = self.get_sent_stream(path) - for batch in self.stream_iterator(sent_stream): - yield batch - - -class Corpus(object): - def __init__(self, path, dataset, *args, **kwargs): - self.dataset = dataset - self.vocab = Vocab(*args, **kwargs) - - if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']: - self.vocab.count_file(os.path.join(path, 'train.txt')) - self.vocab.count_file(os.path.join(path, 'valid.txt')) - self.vocab.count_file(os.path.join(path, 'test.txt')) - elif self.dataset == 'wt103': - self.vocab.count_file(os.path.join(path, 'train.txt')) - elif self.dataset == 'lm1b': - train_path_pattern = os.path.join( - path, '1-billion-word-language-modeling-benchmark-r13output', - 'training-monolingual.tokenized.shuffled', 'news.en-*') - train_paths = glob.glob(train_path_pattern) - # the vocab will load from file when build_vocab() is called - - self.vocab.build_vocab() - - if self.dataset in ['ptb', 'wt2', 'wt103']: - self.train = self.vocab.encode_file( - os.path.join(path, 'train.txt'), ordered=True) - self.valid = self.vocab.encode_file( - os.path.join(path, 'valid.txt'), ordered=True) - self.test = self.vocab.encode_file( - os.path.join(path, 'test.txt'), ordered=True) - elif self.dataset in ['enwik8', 'text8']: - self.train = self.vocab.encode_file( - os.path.join(path, 'train.txt'), ordered=True, add_eos=False) - self.valid = self.vocab.encode_file( - os.path.join(path, 'valid.txt'), ordered=True, add_eos=False) - self.test = self.vocab.encode_file( - os.path.join(path, 'test.txt'), ordered=True, add_eos=False) - elif self.dataset == 'lm1b': - self.train = train_paths - self.valid = self.vocab.encode_file( - os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True) - self.test = self.vocab.encode_file( - os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True) - - def get_iterator(self, split, *args, **kwargs): - if split == 'train': - if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: - data_iter = LMOrderedIterator(self.train, *args, **kwargs) - elif self.dataset == 'lm1b': - kwargs['shuffle'] = True - data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) - elif split in ['valid', 'test']: - data = self.valid if split == 'valid' else self.test - if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']: - data_iter = LMOrderedIterator(data, *args, **kwargs) - elif self.dataset == 'lm1b': - data_iter = LMShuffledIterator(data, *args, **kwargs) - - return data_iter - - -def get_lm_corpus(datadir, dataset): - fn = os.path.join(datadir, 'cache.pt') - if os.path.exists(fn): - print('Loading cached dataset...') - corpus = torch.load(fn) - else: - print('Producing dataset {}...'.format(dataset)) - kwargs = {} - if dataset in ['wt103', 'wt2']: - kwargs['special'] = ['<eos>'] - kwargs['lower_case'] = False - elif dataset == 'ptb': - kwargs['special'] = ['<eos>'] - kwargs['lower_case'] = True - elif dataset == 'lm1b': - kwargs['special'] = [] - kwargs['lower_case'] = False - kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt') - elif dataset in ['enwik8', 'text8']: - pass - - corpus = Corpus(datadir, dataset, **kwargs) - torch.save(corpus, fn) - - return corpus - -if
__name__ == '__main__': - import argparse - parser = argparse.ArgumentParser(description='unit test') - parser.add_argument('--datadir', type=str, default='../data/text8', - help='location of the data corpus') - parser.add_argument('--dataset', type=str, default='text8', - choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'], - help='dataset name') - args = parser.parse_args() - - corpus = get_lm_corpus(args.datadir, args.dataset) - print('Vocab size : {}'.format(len(corpus.vocab.idx2sym))) diff --git a/transformer-xl/pytorch/eval.py b/transformer-xl/pytorch/eval.py deleted file mode 100644 index eff3618..0000000 --- a/transformer-xl/pytorch/eval.py +++ /dev/null @@ -1,122 +0,0 @@ -# coding: utf-8 -import argparse -import time -import math -import os, sys - -import torch - -from data_utils import get_lm_corpus -from mem_transformer import MemTransformerLM -from utils.exp_utils import get_logger - -parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') -parser.add_argument('--data', type=str, default='../data/wikitext-103', - help='location of the data corpus') -parser.add_argument('--dataset', type=str, default='wt103', - choices=['wt103', 'lm1b', 'enwik8', 'text8'], - help='dataset name') -parser.add_argument('--split', type=str, default='all', - choices=['all', 'valid', 'test'], - help='which split to evaluate') -parser.add_argument('--batch_size', type=int, default=10, - help='batch size') -parser.add_argument('--tgt_len', type=int, default=5, - help='number of tokens to predict') -parser.add_argument('--ext_len', type=int, default=0, - help='length of the extended context') -parser.add_argument('--mem_len', type=int, default=0, - help='length of the retained previous heads') -parser.add_argument('--clamp_len', type=int, default=-1, - help='max positional embedding index') -parser.add_argument('--cuda', action='store_true', - help='use CUDA') -parser.add_argument('--work_dir', type=str, required=True, - help='path to the work_dir') -parser.add_argument('--no_log', action='store_true', - help='do not log the eval result') -parser.add_argument('--same_length', action='store_true', - help='set same length attention with masking') -args = parser.parse_args() -assert args.ext_len >= 0, 'extended context length must be non-negative' - -device = torch.device("cuda" if args.cuda else "cpu") - -# Get logger -logging = get_logger(os.path.join(args.work_dir, 'log.txt'), - log_=not args.no_log) - -# Load dataset -corpus = get_lm_corpus(args.data, args.dataset) -ntokens = len(corpus.vocab) - -va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len, - device=device, ext_len=args.ext_len) -te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, - device=device, ext_len=args.ext_len) - -# Load the best saved model. 
-with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: - model = torch.load(f) -model.backward_compatible() -model = model.to(device) - -logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( - args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) - -model.reset_length(args.tgt_len, args.ext_len, args.mem_len) -if args.clamp_len > 0: - model.clamp_len = args.clamp_len -if args.same_length: - model.same_length = True - -############################################################################### -# Evaluation code -############################################################################### -def evaluate(eval_iter): - # Turn on evaluation mode which disables dropout. - model.eval() - total_len, total_loss = 0, 0. - start_time = time.time() - with torch.no_grad(): - mems = tuple() - for idx, (data, target, seq_len) in enumerate(eval_iter): - ret = model(data, target, *mems) - loss, mems = ret[0], ret[1:] - loss = loss.mean() - total_loss += seq_len * loss.item() - total_len += seq_len - total_time = time.time() - start_time - logging('Time : {:.2f}s, {:.2f}ms/segment'.format( - total_time, 1000 * total_time / (idx+1))) - return total_loss / total_len - -# Run on test data. -if args.split == 'all': - test_loss = evaluate(te_iter) - valid_loss = evaluate(va_iter) -elif args.split == 'valid': - valid_loss = evaluate(va_iter) - test_loss = None -elif args.split == 'test': - test_loss = evaluate(te_iter) - valid_loss = None - -def format_log(loss, split): - if args.dataset in ['enwik8', 'text8']: - log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format( - split, loss, loss / math.log(2)) - else: - log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format( - split, loss, math.exp(loss)) - return log_str - -log_str = '' -if valid_loss is not None: - log_str += format_log(valid_loss, 'valid') -if test_loss is not None: - log_str += format_log(test_loss, 'test') - -logging('=' * 100) -logging(log_str) -logging('=' * 100) diff --git a/transformer-xl/pytorch/mem_transformer.py b/transformer-xl/pytorch/mem_transformer.py deleted file mode 100644 index ed02ee9..0000000 --- a/transformer-xl/pytorch/mem_transformer.py +++ /dev/null @@ -1,812 +0,0 @@ -import sys -import math -import functools - -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -sys.path.append('utils') -from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax -from log_uniform_sampler import LogUniformSampler, sample_logits - -class PositionalEmbedding(nn.Module): - def __init__(self, demb): - super(PositionalEmbedding, self).__init__() - - self.demb = demb - - inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb)) - self.register_buffer('inv_freq', inv_freq) - - def forward(self, pos_seq, bsz=None): - sinusoid_inp = torch.ger(pos_seq, self.inv_freq) - pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1) - - if bsz is not None: - return pos_emb[:,None,:].expand(-1, bsz, -1) - else: - return pos_emb[:,None,:] - - -class PositionwiseFF(nn.Module): - def __init__(self, d_model, d_inner, dropout, pre_lnorm=False): - super(PositionwiseFF, self).__init__() - - self.d_model = d_model - self.d_inner = d_inner - self.dropout = dropout - - self.CoreNet = nn.Sequential( - nn.Linear(d_model, d_inner), nn.ReLU(inplace=True), - nn.Dropout(dropout), - nn.Linear(d_inner, d_model), - nn.Dropout(dropout), - ) - - self.layer_norm = nn.LayerNorm(d_model) - - self.pre_lnorm = pre_lnorm - - def forward(self, inp): - 
if self.pre_lnorm: - ##### layer normalization + positionwise feed-forward - core_out = self.CoreNet(self.layer_norm(inp)) - - ##### residual connection - output = core_out + inp - else: - ##### positionwise feed-forward - core_out = self.CoreNet(inp) - - ##### residual connection + layer normalization - output = self.layer_norm(inp + core_out) - - return output - -class MultiHeadAttn(nn.Module): - def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, - pre_lnorm=False): - super(MultiHeadAttn, self).__init__() - - self.n_head = n_head - self.d_model = d_model - self.d_head = d_head - self.dropout = dropout - - self.q_net = nn.Linear(d_model, n_head * d_head, bias=False) - self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False) - - self.drop = nn.Dropout(dropout) - self.dropatt = nn.Dropout(dropatt) - self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) - - self.layer_norm = nn.LayerNorm(d_model) - - self.scale = 1 / (d_head ** 0.5) - - self.pre_lnorm = pre_lnorm - - def forward(self, h, attn_mask=None, mems=None): - ##### multihead attention - # [hlen x bsz x n_head x d_head] - - if mems is not None: - c = torch.cat([mems, h], 0) - else: - c = h - - if self.pre_lnorm: - ##### layer normalization - c = self.layer_norm(c) - - head_q = self.q_net(h) - head_k, head_v = torch.chunk(self.kv_net(c), 2, -1) - - head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head) - head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head) - head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head) - - # [qlen x klen x bsz x n_head] - attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k)) - attn_score.mul_(self.scale) - if attn_mask is not None and attn_mask.any().item(): - if attn_mask.dim() == 2: - attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) - elif attn_mask.dim() == 3: - attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) - - # [qlen x klen x bsz x n_head] - attn_prob = F.softmax(attn_score, dim=1) - attn_prob = self.dropatt(attn_prob) - - # [qlen x klen x bsz x n_head] + [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head] - attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v)) - attn_vec = attn_vec.contiguous().view( - attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) - - ##### linear projection - attn_out = self.o_net(attn_vec) - attn_out = self.drop(attn_out) - - if self.pre_lnorm: - ##### residual connection - output = h + attn_out - else: - ##### residual connection + layer normalization - output = self.layer_norm(h + attn_out) - - return output - -class RelMultiHeadAttn(nn.Module): - def __init__(self, n_head, d_model, d_head, dropout, dropatt=0, - tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False): - super(RelMultiHeadAttn, self).__init__() - - self.n_head = n_head - self.d_model = d_model - self.d_head = d_head - self.dropout = dropout - - self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False) - - self.drop = nn.Dropout(dropout) - self.dropatt = nn.Dropout(dropatt) - self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) - - self.layer_norm = nn.LayerNorm(d_model) - - self.scale = 1 / (d_head ** 0.5) - - self.pre_lnorm = pre_lnorm - - def _parallelogram_mask(self, h, w, left=False): - mask = torch.ones((h, w)).byte() - m = min(h, w) - mask[:m,:m] = torch.triu(mask[:m,:m]) - mask[-m:,-m:] = torch.tril(mask[-m:,-m:]) - - if left: - return mask - else: - return mask.flip(0) - - def _shift(self, x, qlen, klen, mask, left=False): - if qlen > 1: 
- zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)), - device=x.device, dtype=x.dtype) - else: - zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype) - - if left: - mask = mask.flip(1) - x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1) - else: - x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1) - - x = x_padded.masked_select(mask[:,:,None,None]) \ - .view(qlen, klen, x.size(2), x.size(3)) - - return x - - def _rel_shift(self, x, zero_triu=False): - zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), - device=x.device, dtype=x.dtype) - x_padded = torch.cat([zero_pad, x], dim=1) - - x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) - - x = x_padded[1:].view_as(x) - - if zero_triu: - ones = torch.ones((x.size(0), x.size(1))) - x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None] - - return x - - def forward(self, w, r, attn_mask=None, mems=None): - raise NotImplementedError - -class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn): - def __init__(self, *args, **kwargs): - super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs) - - self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False) - - def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None): - qlen, rlen, bsz = w.size(0), r.size(0), w.size(1) - - if mems is not None: - cat = torch.cat([mems, w], 0) - if self.pre_lnorm: - w_heads = self.qkv_net(self.layer_norm(cat)) - else: - w_heads = self.qkv_net(cat) - r_head_k = self.r_net(r) - - w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) - w_head_q = w_head_q[-qlen:] - else: - if self.pre_lnorm: - w_heads = self.qkv_net(self.layer_norm(w)) - else: - w_heads = self.qkv_net(w) - r_head_k = self.r_net(r) - - w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) - - klen = w_head_k.size(0) - - w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head - w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head - w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) # qlen x bsz x n_head x d_head - - r_head_k = r_head_k.view(rlen, self.n_head, self.d_head) # qlen x n_head x d_head - - #### compute attention score - rw_head_q = w_head_q + r_w_bias # qlen x bsz x n_head x d_head - AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head - - rr_head_q = w_head_q + r_r_bias - BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k)) # qlen x klen x bsz x n_head - BD = self._rel_shift(BD) - - # [qlen x klen x bsz x n_head] - attn_score = AC + BD - attn_score.mul_(self.scale) - - #### compute attention probability - if attn_mask is not None and attn_mask.any().item(): - if attn_mask.dim() == 2: - attn_score = attn_score.float().masked_fill( - attn_mask[None,:,:,None], -float('inf')).type_as(attn_score) - elif attn_mask.dim() == 3: - attn_score = attn_score.float().masked_fill( - attn_mask[:,:,:,None], -float('inf')).type_as(attn_score) - - # [qlen x klen x bsz x n_head] - attn_prob = F.softmax(attn_score, dim=1) - attn_prob = self.dropatt(attn_prob) - - #### compute attention vector - attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) - - # [qlen x bsz x n_head x d_head] - attn_vec = attn_vec.contiguous().view( - attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) - - ##### linear projection - attn_out = self.o_net(attn_vec) - attn_out = self.drop(attn_out) - - if self.pre_lnorm: - ##### residual connection -
output = w + attn_out - else: - ##### residual connection + layer normalization - output = self.layer_norm(w + attn_out) - - return output - -class RelLearnableMultiHeadAttn(RelMultiHeadAttn): - def __init__(self, *args, **kwargs): - super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs) - - def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None): - # r_emb: [klen, n_head, d_head], used for term B - # r_w_bias: [n_head, d_head], used for term C - # r_bias: [klen, n_head], used for term D - - qlen, bsz = w.size(0), w.size(1) - - if mems is not None: - cat = torch.cat([mems, w], 0) - if self.pre_lnorm: - w_heads = self.qkv_net(self.layer_norm(cat)) - else: - w_heads = self.qkv_net(cat) - w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) - - w_head_q = w_head_q[-qlen:] - else: - if self.pre_lnorm: - w_heads = self.qkv_net(self.layer_norm(w)) - else: - w_heads = self.qkv_net(w) - w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1) - - klen = w_head_k.size(0) - - w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head) - w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head) - w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head) - - if klen > r_emb.size(0): - r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1) - r_emb = torch.cat([r_emb_pad, r_emb], 0) - r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1) - r_bias = torch.cat([r_bias_pad, r_bias], 0) - else: - r_emb = r_emb[-klen:] - r_bias = r_bias[-klen:] - - #### compute attention score - rw_head_q = w_head_q + r_w_bias[None] # qlen x bsz x n_head x d_head - - AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k)) # qlen x klen x bsz x n_head - B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb)) # qlen x klen x bsz x n_head - D_ = r_bias[None, :, None] # 1 x klen x 1 x n_head - BD = self._rel_shift(B_ + D_) - - # [qlen x klen x bsz x n_head] - attn_score = AC + BD - attn_score.mul_(self.scale) - - #### compute attention probability - if attn_mask is not None and attn_mask.any().item(): - if attn_mask.dim() == 2: - attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf')) - elif attn_mask.dim() == 3: - attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf')) - - # [qlen x klen x bsz x n_head] - attn_prob = F.softmax(attn_score, dim=1) - attn_prob = self.dropatt(attn_prob) - - #### compute attention vector - attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v)) - - # [qlen x bsz x n_head x d_head] - attn_vec = attn_vec.contiguous().view( - attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head) - - ##### linear projection - attn_out = self.o_net(attn_vec) - attn_out = self.drop(attn_out) - - if self.pre_lnorm: - ##### residual connection - output = w + attn_out - else: - ##### residual connection + layer normalization - output = self.layer_norm(w + attn_out) - - return output - -class DecoderLayer(nn.Module): - def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs): - super(DecoderLayer, self).__init__() - - self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs) - self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, - pre_lnorm=kwargs.get('pre_lnorm')) - - def forward(self, dec_inp, dec_attn_mask=None, mems=None): - - output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask, - mems=mems) - output = self.pos_ff(output) - - return output - -class RelLearnableDecoderLayer(nn.Module): - def __init__(self, n_head, d_model, d_head, d_inner, dropout, - **kwargs): - 
super(RelLearnableDecoderLayer, self).__init__() - - self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout, - **kwargs) - self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, - pre_lnorm=kwargs.get('pre_lnorm')) - - def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None): - - output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias, - attn_mask=dec_attn_mask, - mems=mems) - output = self.pos_ff(output) - - return output - -class RelPartialLearnableDecoderLayer(nn.Module): - def __init__(self, n_head, d_model, d_head, d_inner, dropout, - **kwargs): - super(RelPartialLearnableDecoderLayer, self).__init__() - - self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model, - d_head, dropout, **kwargs) - self.pos_ff = PositionwiseFF(d_model, d_inner, dropout, - pre_lnorm=kwargs.get('pre_lnorm')) - - def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None): - - output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias, - attn_mask=dec_attn_mask, - mems=mems) - output = self.pos_ff(output) - - return output - - -class AdaptiveEmbedding(nn.Module): - def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, - sample_softmax=False): - super(AdaptiveEmbedding, self).__init__() - - self.n_token = n_token - self.d_embed = d_embed - - self.cutoffs = cutoffs + [n_token] - self.div_val = div_val - self.d_proj = d_proj - - self.emb_scale = d_proj ** 0.5 - - self.cutoff_ends = [0] + self.cutoffs - - self.emb_layers = nn.ModuleList() - self.emb_projs = nn.ParameterList() - if div_val == 1: - self.emb_layers.append( - nn.Embedding(n_token, d_embed, sparse=sample_softmax>0) - ) - if d_proj != d_embed: - self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed))) - else: - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] - d_emb_i = d_embed // (div_val ** i) - self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i)) - self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i))) - - def forward(self, inp): - if self.div_val == 1: - embed = self.emb_layers[0](inp) - if self.d_proj != self.d_embed: - embed = F.linear(embed, self.emb_projs[0]) - else: - param = next(self.parameters()) - inp_flat = inp.view(-1) - emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], - dtype=param.dtype, device=param.device) - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - - mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) - indices_i = mask_i.nonzero().squeeze() - - if indices_i.numel() == 0: - continue - - inp_i = inp_flat.index_select(0, indices_i) - l_idx - emb_i = self.emb_layers[i](inp_i) - emb_i = F.linear(emb_i, self.emb_projs[i]) - - emb_flat.index_copy_(0, indices_i, emb_i) - - embed = emb_flat.view(*inp.size(), self.d_proj) - - embed.mul_(self.emb_scale) - - return embed - -class MemTransformerLM(nn.Module): - def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner, - dropout, dropatt, tie_weight=True, d_embed=None, - div_val=1, tie_projs=[False], pre_lnorm=False, - tgt_len=None, ext_len=None, mem_len=None, - cutoffs=[], adapt_inp=False, - same_length=False, attn_type=0, clamp_len=-1, - sample_softmax=-1): - super(MemTransformerLM, self).__init__() - self.n_token = n_token - - d_embed = d_model if d_embed is None else d_embed - self.d_embed = d_embed - self.d_model = d_model - self.n_head = n_head - self.d_head = d_head - - self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs, - div_val=div_val) - - 
self.drop = nn.Dropout(dropout) - - self.n_layer = n_layer - - self.tgt_len = tgt_len - self.mem_len = mem_len - self.ext_len = ext_len - self.max_klen = tgt_len + ext_len + mem_len - - self.attn_type = attn_type - - self.layers = nn.ModuleList() - if attn_type == 0: # the default attention - for i in range(n_layer): - self.layers.append( - RelPartialLearnableDecoderLayer( - n_head, d_model, d_head, d_inner, dropout, - tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, - dropatt=dropatt, pre_lnorm=pre_lnorm) - ) - elif attn_type == 1: # learnable embeddings - for i in range(n_layer): - self.layers.append( - RelLearnableDecoderLayer( - n_head, d_model, d_head, d_inner, dropout, - tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, - dropatt=dropatt, pre_lnorm=pre_lnorm) - ) - elif attn_type in [2, 3]: # absolute embeddings - for i in range(n_layer): - self.layers.append( - DecoderLayer( - n_head, d_model, d_head, d_inner, dropout, - dropatt=dropatt, pre_lnorm=pre_lnorm) - ) - - self.sample_softmax = sample_softmax - # use sampled softmax - if sample_softmax > 0: - self.out_layer = nn.Linear(d_model, n_token) - if tie_weight: - self.out_layer.weight = self.word_emb.weight - self.tie_weight = tie_weight - self.sampler = LogUniformSampler(n_token, sample_softmax) - - # use adaptive softmax (including standard softmax) - else: - self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model, - cutoffs, div_val=div_val) - - if tie_weight: - for i in range(len(self.crit.out_layers)): - self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight - - if tie_projs: - for i, tie_proj in enumerate(tie_projs): - if tie_proj and div_val == 1 and d_model != d_embed: - self.crit.out_projs[i] = self.word_emb.emb_projs[0] - elif tie_proj and div_val != 1: - self.crit.out_projs[i] = self.word_emb.emb_projs[i] - - self.same_length = same_length - self.clamp_len = clamp_len - - self._create_params() - - def backward_compatible(self): - self.sample_softmax = -1 - - def _create_params(self): - if self.attn_type == 0: # default attention - self.pos_emb = PositionalEmbedding(self.d_model) - self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) - self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head)) - elif self.attn_type == 1: # learnable - self.r_emb = nn.Parameter(torch.Tensor( - self.n_layer, self.max_klen, self.n_head, self.d_head)) - self.r_w_bias = nn.Parameter(torch.Tensor( - self.n_layer, self.n_head, self.d_head)) - self.r_bias = nn.Parameter(torch.Tensor( - self.n_layer, self.max_klen, self.n_head)) - elif self.attn_type == 2: # absolute standard - self.pos_emb = PositionalEmbedding(self.d_model) - elif self.attn_type == 3: # absolute deeper SA - self.r_emb = nn.Parameter(torch.Tensor( - self.n_layer, self.max_klen, self.n_head, self.d_head)) - - def reset_length(self, tgt_len, ext_len, mem_len): - self.tgt_len = tgt_len - self.mem_len = mem_len - self.ext_len = ext_len - - def init_mems(self): - if self.mem_len > 0: - mems = [] - param = next(self.parameters()) - for i in range(self.n_layer+1): - empty = torch.empty(0, dtype=param.dtype, device=param.device) - mems.append(empty) - - return mems - else: - return None - - def _update_mems(self, hids, mems, qlen, mlen): - # does not deal with None - if mems is None: return None - - # mems is not None - assert len(hids) == len(mems), 'len(hids) != len(mems)' - - # There are `mlen + qlen` steps that can be cached into mems - # For the next step, the last `ext_len` of the `qlen` tokens - # will be used as the extended 
context. Hence, we only cache - # the tokens from `mlen + qlen - self.ext_len - self.mem_len` - # to `mlen + qlen - self.ext_len`. - with torch.no_grad(): - new_mems = [] - end_idx = mlen + max(0, qlen - 0 - self.ext_len) - beg_idx = max(0, end_idx - self.mem_len) - for i in range(len(hids)): - - cat = torch.cat([mems[i], hids[i]], dim=0) - new_mems.append(cat[beg_idx:end_idx].detach()) - - return new_mems - - def _forward(self, dec_inp, mems=None): - qlen, bsz = dec_inp.size() - - word_emb = self.word_emb(dec_inp) - - mlen = mems[0].size(0) if mems is not None else 0 - klen = mlen + qlen - if self.same_length: - all_ones = word_emb.new_ones(qlen, klen) - mask_len = klen - self.mem_len - if mask_len > 0: - mask_shift_len = qlen - mask_len - else: - mask_shift_len = qlen - dec_attn_mask = (torch.triu(all_ones, 1+mlen) - + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1 - else: - dec_attn_mask = torch.triu( - word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None] - - hids = [] - if self.attn_type == 0: # default - pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.DEVICE, - dtype=word_emb.dtype) - if self.clamp_len > 0: - pos_seq.clamp_(max=self.clamp_len) - pos_emb = self.pos_emb(pos_seq) - - core_out = self.drop(word_emb) - pos_emb = self.drop(pos_emb) - - hids.append(core_out) - for i, layer in enumerate(self.layers): - mems_i = None if mems is None else mems[i] - core_out = layer(core_out, pos_emb, self.r_w_bias, - self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) - hids.append(core_out) - elif self.attn_type == 1: # learnable - core_out = self.drop(word_emb) - hids.append(core_out) - for i, layer in enumerate(self.layers): - if self.clamp_len > 0: - r_emb = self.r_emb[i][-self.clamp_len :] - r_bias = self.r_bias[i][-self.clamp_len :] - else: - r_emb, r_bias = self.r_emb[i], self.r_bias[i] - - mems_i = None if mems is None else mems[i] - core_out = layer(core_out, r_emb, self.r_w_bias[i], - r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i) - hids.append(core_out) - elif self.attn_type == 2: # absolute - pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.DEVICE, - dtype=word_emb.dtype) - if self.clamp_len > 0: - pos_seq.clamp_(max=self.clamp_len) - pos_emb = self.pos_emb(pos_seq) - - core_out = self.drop(word_emb + pos_emb[-qlen:]) - - hids.append(core_out) - for i, layer in enumerate(self.layers): - mems_i = None if mems is None else mems[i] - if mems_i is not None and i == 0: - mems_i += pos_emb[:mlen] - core_out = layer(core_out, dec_attn_mask=dec_attn_mask, - mems=mems_i) - hids.append(core_out) - elif self.attn_type == 3: - core_out = self.drop(word_emb) - - hids.append(core_out) - for i, layer in enumerate(self.layers): - mems_i = None if mems is None else mems[i] - if mems_i is not None and mlen > 0: - cur_emb = self.r_emb[i][:-qlen] - cur_size = cur_emb.size(0) - if cur_size < mlen: - cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1) - cur_emb = torch.cat([cur_emb_pad, cur_emb], 0) - else: - cur_emb = cur_emb[-mlen:] - mems_i += cur_emb.view(mlen, 1, -1) - core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1) - - core_out = layer(core_out, dec_attn_mask=dec_attn_mask, - mems=mems_i) - hids.append(core_out) - - core_out = self.drop(core_out) - - new_mems = self._update_mems(hids, mems, mlen, qlen) - - return core_out, new_mems - - def forward(self, data, target, *mems): - # nn.DataParallel does not allow size(0) tensors to be broadcasted. - # So, have to initialize size(0) mems inside the model forward. 
- # Moreover, have to return new_mems to allow nn.DataParallel to piece - # them together. - if not mems: mems = self.init_mems() - - tgt_len = target.size(0) - hidden, new_mems = self._forward(data, mems=mems) - - pred_hid = hidden[-tgt_len:] - if self.sample_softmax > 0 and self.training: - assert self.tie_weight - logit = sample_logits(self.word_emb, - self.out_layer.bias, target, pred_hid, self.sampler) - loss = -F.log_softmax(logit, -1)[:, :, 0] - else: - loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1)) - loss = loss.view(tgt_len, -1) - - if new_mems is None: - return [loss] - else: - return [loss] + new_mems - -if __name__ == '__main__': - import argparse - - parser = argparse.ArgumentParser(description='unit test') - - parser.add_argument('--n_layer', type=int, default=4, help='') - parser.add_argument('--n_rel_layer', type=int, default=4, help='') - parser.add_argument('--n_head', type=int, default=2, help='') - parser.add_argument('--d_head', type=int, default=2, help='') - parser.add_argument('--d_model', type=int, default=200, help='') - parser.add_argument('--d_embed', type=int, default=200, help='') - parser.add_argument('--d_inner', type=int, default=200, help='') - parser.add_argument('--dropout', type=float, default=0.0, help='') - parser.add_argument('--cuda', action='store_true', help='') - parser.add_argument('--seed', type=int, default=1111, help='') - parser.add_argument('--multi_gpu', action='store_true', help='') - - args = parser.parse_args() - - device = torch.device("cuda" if args.cuda else "cpu") - - B = 4 - tgt_len, mem_len, ext_len = 36, 36, 0 - data_len = tgt_len * 20 - args.n_token = 10000 - - import data_utils - - data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device) - diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len) - - cutoffs = [args.n_token // 2] - tie_projs = [False] + [True] * len(cutoffs) - - for div_val in [1, 2]: - for d_embed in [200, 100]: - model = MemTransformerLM(args.n_token, args.n_layer, args.n_head, - args.d_model, args.d_head, args.d_inner, args.dropout, - dropatt=args.dropout, tie_weight=True, - d_embed=d_embed, div_val=div_val, - tie_projs=tie_projs, pre_lnorm=True, - tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len, - cutoffs=cutoffs, attn_type=0).to(device) - - print(sum(p.numel() for p in model.parameters())) - - mems = tuple() - for idx, (inp, tgt, seqlen) in enumerate(diter): - print('batch {}'.format(idx)) - out = model(inp, tgt, *mems) - mems = out[1:] diff --git a/transformer-xl/pytorch/run_enwik8_base.sh b/transformer-xl/pytorch/run_enwik8_base.sh deleted file mode 100644 index db542a8..0000000 --- a/transformer-xl/pytorch/run_enwik8_base.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --cuda \ - --data ../data/enwik8/ \ - --dataset enwik8 \ - --n_layer 12 \ - --d_model 512 \ - --n_head 8 \ - --d_head 64 \ - --d_inner 2048 \ - --dropout 0.1 \ - --dropatt 0.0 \ - --optim adam \ - --lr 0.00025 \ - --warmup_step 0 \ - --max_step 400000 \ - --tgt_len 512 \ - --mem_len 512 \ - --eval_tgt_len 128 \ - --batch_size 22 \ - --multi_gpu \ - --gpu0_bsz 4 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
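The segment-level recurrence implemented by `_update_mems` above can be read in isolation as: concatenate each layer's cached memory with the new hidden states, keep only the most recent `mem_len` positions, and detach so gradients never cross segment boundaries. A minimal standalone sketch, with a hypothetical helper name and shapes assumed as `[seq_len, bsz, d_model]`:

```python
import torch

def update_memory(hids, mems, mem_len, ext_len=0):
    """Illustrative sketch of the segment-level recurrence in Transformer-XL.

    hids: per-layer hidden states for the current segment, each [qlen, bsz, d_model]
    mems: per-layer cached states from earlier segments, each [mlen, bsz, d_model]
    Returns per-layer memories of at most `mem_len` positions, detached so that
    no gradient flows across segment boundaries.
    """
    qlen = hids[0].size(0)
    mlen = mems[0].size(0)
    with torch.no_grad():
        new_mems = []
        end_idx = mlen + max(0, qlen - ext_len)   # last position eligible for caching
        beg_idx = max(0, end_idx - mem_len)       # keep at most mem_len positions
        for h, m in zip(hids, mems):
            cat = torch.cat([m, h], dim=0)        # [mlen + qlen, bsz, d_model]
            new_mems.append(cat[beg_idx:end_idx].detach())
    return new_mems

# Toy usage: two layers, segments of length 3, memory capped at 4 positions.
bsz, d = 1, 8
mems = [torch.zeros(0, bsz, d) for _ in range(2)]
for _ in range(3):
    hids = [torch.randn(3, bsz, d) for _ in range(2)]
    mems = update_memory(hids, mems, mem_len=4)
print([tuple(m.shape) for m in mems])  # [(4, 1, 8), (4, 1, 8)]
```

Detaching is what keeps the recurrence cheap: cached states extend the attention context of later segments without growing the backward graph.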
- python eval.py \ - --cuda \ - --data ../data/enwik8/ \ - --dataset enwik8 \ - --tgt_len 80 \ - --mem_len 2100 \ - --clamp_len 820 \ - --same_length \ - --split test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/run_enwik8_large.sh b/transformer-xl/pytorch/run_enwik8_large.sh deleted file mode 100644 index 5db67bf..0000000 --- a/transformer-xl/pytorch/run_enwik8_large.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --cuda \ - --data ../data/enwik8/ \ - --dataset enwik8 \ - --n_layer 24 \ - --d_model 1024 \ - --n_head 8 \ - --d_head 128 \ - --d_inner 3072 \ - --dropout 0.15 \ - --dropatt 0.15 \ - --optim adam \ - --lr 0.00025 \ - --warmup_step 4000 \ - --max_step 400000 \ - --tgt_len 768 \ - --mem_len 768 \ - --eval_tgt_len 128 \ - --batch_size 64 \ - --multi_gpu \ - --gpu0_bsz 0 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python eval.py \ - --cuda \ - --data ../data/enwik8/ \ - --dataset enwik8 \ - --tgt_len 128 \ - --mem_len 3800 \ - --clamp_len 1000 \ - --same_length \ - --split test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/run_lm1b_base.sh b/transformer-xl/pytorch/run_lm1b_base.sh deleted file mode 100644 index e4aebef..0000000 --- a/transformer-xl/pytorch/run_lm1b_base.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --cuda \ - --data ../data/one-billion-words/ \ - --dataset lm1b \ - --adaptive \ - --n_layer 18 \ - --d_model 1024 \ - --div_val 4 \ - --n_head 8 \ - --d_head 128 \ - --d_inner 4096 \ - --dropout 0.0 \ - --dropatt 0.0 \ - --optim adam \ - --warmup_step 20000 \ - --max_step 500000 \ - --lr 0.00025 \ - --tgt_len 32 \ - --mem_len 32 \ - --eval_tgt_len 32 \ - --batch_size 224 \ - --multi_gpu \ - --gpu0_bsz 32 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python eval.py \ - --cuda \ - --data ../data/one-billion-words/ \ - --dataset lm1b \ - --batch_size 64 \ - --tgt_len 32 \ - --mem_len 128 \ - --split test \ - --same_length \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/run_lm1b_large.sh b/transformer-xl/pytorch/run_lm1b_large.sh deleted file mode 100644 index f8b330a..0000000 --- a/transformer-xl/pytorch/run_lm1b_large.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --cuda \ - --data ../data/one-billion-words/ \ - --dataset lm1b \ - --adaptive \ - --div_val 4 \ - --n_layer 24 \ - --d_model 1280 \ - --n_head 16 \ - --d_head 80 \ - --d_inner 8192 \ - --dropout 0.05 \ - --dropatt 0.05 \ - --optim adam \ - --warmup_step 30000 \ - --max_step 1200000 \ - --lr 0.00025 \ - --tgt_len 32 \ - --mem_len 32 \ - --eval_tgt_len 32 \ - --batch_size 512 \ - --multi_gpu \ - --gpu0_bsz 0 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python eval.py \ - --cuda \ - --data ../data/one-billion-words/ \ - --dataset lm1b \ - --batch_size 8 \ - --tgt_len 32 \ - --mem_len 128 \ - --split test \ - --same_length \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/run_text8_base.sh b/transformer-xl/pytorch/run_text8_base.sh deleted file mode 100644 index 7058f77..0000000 --- a/transformer-xl/pytorch/run_text8_base.sh +++ /dev/null @@ -1,41 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' 
- python train.py \ - --cuda \ - --data ../data/text8/ \ - --dataset text8 \ - --n_layer 12 \ - --d_model 512 \ - --n_head 8 \ - --d_head 64 \ - --d_inner 2048 \ - --dropout 0.1 \ - --dropatt 0.0 \ - --optim adam \ - --lr 0.00025 \ - --warmup_step 0 \ - --max_step 400000 \ - --tgt_len 512 \ - --mem_len 512 \ - --eval_tgt_len 128 \ - --batch_size 22 \ - --multi_gpu \ - --gpu0_bsz 4 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python eval.py \ - --cuda \ - --data ../data/text8/ \ - --dataset text8 \ - --tgt_len 80 \ - --mem_len 2100 \ - --clamp_len 820 \ - --same_length \ - --split test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/run_text8_large.sh b/transformer-xl/pytorch/run_text8_large.sh deleted file mode 100644 index cfc84df..0000000 --- a/transformer-xl/pytorch/run_text8_large.sh +++ /dev/null @@ -1,38 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --cuda \ - --data ../data/text8/ \ - --dataset text8 \ - --n_layer 24 \ - --d_model 1024 \ - --n_head 8 \ - --d_head 128 \ - --d_inner 3072 \ - --dropout 0.15 \ - --dropatt 0.15 \ - --optim adam \ - --lr 0.00025 \ - --tgt_len 768 \ - --mem_len 768 \ - --eval_tgt_len 128 \ - --batch_size 64 \ - --max_step 400000 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python eval.py \ - --cuda \ - --data ../data/text8/ \ - --dataset text8 \ - --tgt_len 128 \ - --mem_len 3800 \ - --clamp_len 1000 \ - --same_length \ - --split test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/run_wt103_base.sh b/transformer-xl/pytorch/run_wt103_base.sh deleted file mode 100644 index 22c7550..0000000 --- a/transformer-xl/pytorch/run_wt103_base.sh +++ /dev/null @@ -1,42 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --cuda \ - --data ../data/wikitext-103/ \ - --dataset wt103 \ - --adaptive \ - --n_layer 16 \ - --d_model 410 \ - --n_head 10 \ - --d_head 41 \ - --d_inner 2100 \ - --dropout 0.1 \ - --dropatt 0.0 \ - --optim adam \ - --lr 0.00025 \ - --warmup_step 0 \ - --max_step 200000 \ - --tgt_len 150 \ - --mem_len 150 \ - --eval_tgt_len 150 \ - --batch_size 60 \ - --multi_gpu \ - --gpu0_bsz 4 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python eval.py \ - --cuda \ - --data ../data/wikitext-103/ \ - --dataset wt103 \ - --tgt_len 64 \ - --mem_len 640 \ - --clamp_len 400 \ - --same_length \ - --split test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/run_wt103_large.sh b/transformer-xl/pytorch/run_wt103_large.sh deleted file mode 100644 index a4e701b..0000000 --- a/transformer-xl/pytorch/run_wt103_large.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/bin/bash - -if [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --cuda \ - --data ../data/wikitext-103/ \ - --dataset wt103 \ - --adaptive \ - --div_val 4 \ - --n_layer 18 \ - --d_model 1024 \ - --n_head 16 \ - --d_head 64 \ - --d_inner 4096 \ - --dropout 0.2 \ - --dropatt 0.2 \ - --optim adam \ - --lr 0.00025 \ - --warmup_step 16000 \ - --max_step 4000000 \ - --tgt_len 384 \ - --mem_len 384 \ - --eval_tgt_len 128 \ - --batch_size 128 \ - --multi_gpu \ - --gpu0_bsz 0 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
- python eval.py \ - --cuda \ - --data ../data/wikitext-103/ \ - --dataset wt103 \ - --tgt_len 128 \ - --mem_len 1600 \ - --clamp_len 1000 \ - --same_length \ - --split test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/pytorch/train.py b/transformer-xl/pytorch/train.py deleted file mode 100644 index 0e00e82..0000000 --- a/transformer-xl/pytorch/train.py +++ /dev/null @@ -1,562 +0,0 @@ -# coding: utf-8 -import argparse -import time -import math -import os, sys -import itertools - -import numpy as np - -import torch -import torch.nn as nn -import torch.optim as optim - -from data_utils import get_lm_corpus -from mem_transformer import MemTransformerLM -from utils.exp_utils import create_exp_dir -from utils.data_parallel import BalancedDataParallel - -parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') -parser.add_argument('--data', type=str, default='../data/wikitext-103', - help='location of the data corpus') -parser.add_argument('--dataset', type=str, default='wt103', - choices=['wt103', 'lm1b', 'enwik8', 'text8'], - help='dataset name') -parser.add_argument('--n_layer', type=int, default=12, - help='number of total layers') -parser.add_argument('--n_head', type=int, default=10, - help='number of heads') -parser.add_argument('--d_head', type=int, default=50, - help='head dimension') -parser.add_argument('--d_embed', type=int, default=-1, - help='embedding dimension') -parser.add_argument('--d_model', type=int, default=500, - help='model dimension') -parser.add_argument('--d_inner', type=int, default=1000, - help='inner dimension in FF') -parser.add_argument('--dropout', type=float, default=0.0, - help='global dropout rate') -parser.add_argument('--dropatt', type=float, default=0.0, - help='attention probability dropout rate') -parser.add_argument('--init', default='normal', type=str, - help='parameter initializer to use.') -parser.add_argument('--emb_init', default='normal', type=str, - help='parameter initializer to use.') -parser.add_argument('--init_range', type=float, default=0.1, - help='parameters initialized by U(-init_range, init_range)') -parser.add_argument('--emb_init_range', type=float, default=0.01, - help='parameters initialized by U(-init_range, init_range)') -parser.add_argument('--init_std', type=float, default=0.02, - help='parameters initialized by N(0, init_std)') -parser.add_argument('--proj_init_std', type=float, default=0.01, - help='parameters initialized by N(0, init_std)') -parser.add_argument('--optim', default='adam', type=str, - choices=['adam', 'sgd', 'adagrad'], - help='optimizer to use.') -parser.add_argument('--lr', type=float, default=0.00025, - help='initial learning rate (0.00025|5 for adam|sgd)') -parser.add_argument('--mom', type=float, default=0.0, - help='momentum for sgd') -parser.add_argument('--scheduler', default='cosine', type=str, - choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'], - help='lr scheduler to use.') -parser.add_argument('--warmup_step', type=int, default=0, - help='upper epoch limit') -parser.add_argument('--decay_rate', type=float, default=0.5, - help='decay factor when ReduceLROnPlateau is used') -parser.add_argument('--lr_min', type=float, default=0.0, - help='minimum learning rate during annealing') -parser.add_argument('--clip', type=float, default=0.25, - help='gradient clipping') -parser.add_argument('--clip_nonemb', action='store_true', - help='only clip the gradient of non-embedding params') -parser.add_argument('--max_step', type=int, default=100000, - 
help='upper epoch limit') -parser.add_argument('--batch_size', type=int, default=60, - help='batch size') -parser.add_argument('--batch_chunk', type=int, default=1, - help='split batch into chunks to save memory') -parser.add_argument('--tgt_len', type=int, default=70, - help='number of tokens to predict') -parser.add_argument('--eval_tgt_len', type=int, default=50, - help='number of tokens to predict for evaluation') -parser.add_argument('--ext_len', type=int, default=0, - help='length of the extended context') -parser.add_argument('--mem_len', type=int, default=0, - help='length of the retained previous heads') -parser.add_argument('--not_tied', action='store_true', - help='do not tie the word embedding and softmax weights') -parser.add_argument('--seed', type=int, default=1111, - help='random seed') -parser.add_argument('--cuda', action='store_true', - help='use CUDA') -parser.add_argument('--adaptive', action='store_true', - help='use adaptive softmax') -parser.add_argument('--div_val', type=int, default=1, - help='divident value for adapative input and softmax') -parser.add_argument('--pre_lnorm', action='store_true', - help='apply LayerNorm to the input instead of the output') -parser.add_argument('--varlen', action='store_true', - help='use variable length') -parser.add_argument('--multi_gpu', action='store_true', - help='use multiple GPU') -parser.add_argument('--log-interval', type=int, default=200, - help='report interval') -parser.add_argument('--eval-interval', type=int, default=4000, - help='evaluation interval') -parser.add_argument('--work_dir', default='LM-TFM', type=str, - help='experiment directory.') -parser.add_argument('--restart', action='store_true', - help='restart training from the saved checkpoint') -parser.add_argument('--restart_dir', type=str, default='', - help='restart dir') -parser.add_argument('--debug', action='store_true', - help='run in debug mode (do not create exp dir)') -parser.add_argument('--same_length', action='store_true', - help='use the same attn length for all tokens') -parser.add_argument('--attn_type', type=int, default=0, - help='attention type. 0 for ours, 1 for Shaw et al,' - '2 for Vaswani et al, 3 for Al Rfou et al.') -parser.add_argument('--clamp_len', type=int, default=-1, - help='use the same pos embeddings after clamp_len') -parser.add_argument('--eta_min', type=float, default=0.0, - help='min learning rate for cosine scheduler') -parser.add_argument('--gpu0_bsz', type=int, default=-1, - help='batch size on gpu 0') -parser.add_argument('--max_eval_steps', type=int, default=-1, - help='max eval steps') -parser.add_argument('--sample_softmax', type=int, default=-1, - help='number of samples in sampled softmax') -parser.add_argument('--patience', type=int, default=0, - help='patience') -parser.add_argument('--finetune_v2', action='store_true', - help='finetune v2') -parser.add_argument('--finetune_v3', action='store_true', - help='finetune v3') -parser.add_argument('--fp16', action='store_true', - help='Run in pseudo-fp16 mode (fp16 storage fp32 math).') -parser.add_argument('--static-loss-scale', type=float, default=1, - help='Static loss scale, positive power of 2 values can ' - 'improve fp16 convergence.') -parser.add_argument('--dynamic-loss-scale', action='store_true', - help='Use dynamic loss scaling. 
If supplied, this argument' - ' supersedes --static-loss-scale.') -args = parser.parse_args() -args.tied = not args.not_tied - -if args.d_embed < 0: - args.d_embed = args.d_model - -assert args.ext_len >= 0, 'extended context length must be non-negative' -assert args.batch_size % args.batch_chunk == 0 - -args.work_dir = '{}-{}'.format(args.work_dir, args.dataset) -args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S')) -logging = create_exp_dir(args.work_dir, - scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug) - -# Set the random seed manually for reproducibility. -np.random.seed(args.seed) -torch.manual_seed(args.seed) -if torch.cuda.is_available(): - if not args.cuda: - print('WARNING: You have a CUDA DEVICE, so you should probably run with --cuda') - else: - torch.cuda.manual_seed_all(args.seed) - -# Validate `--fp16` option -if args.fp16: - if not args.cuda: - print('WARNING: --fp16 requires --cuda, ignoring --fp16 option') - args.fp16 = False - else: - try: - from apex.fp16_utils import FP16_Optimizer - except: - print('WARNING: apex not installed, ignoring --fp16 option') - args.fp16 = False - -device = torch.device('cuda' if args.cuda else 'cpu') - -############################################################################### -# Load data -############################################################################### -corpus = get_lm_corpus(args.data, args.dataset) -ntokens = len(corpus.vocab) -args.n_token = ntokens - -eval_batch_size = 10 -tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len, - device=device, ext_len=args.ext_len) -va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len, - device=device, ext_len=args.ext_len) -te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len, - device=device, ext_len=args.ext_len) - -# adaptive softmax / embedding -cutoffs, tie_projs = [], [False] -if args.adaptive: - assert args.dataset in ['wt103', 'lm1b'] - if args.dataset == 'wt103': - cutoffs = [20000, 40000, 200000] - tie_projs += [True] * len(cutoffs) - elif args.dataset == 'lm1b': - cutoffs = [60000, 100000, 640000] - tie_projs += [False] * len(cutoffs) - -############################################################################### -# Build the model -############################################################################### -def init_weight(weight): - if args.init == 'uniform': - nn.init.uniform_(weight, -args.init_range, args.init_range) - elif args.init == 'normal': - nn.init.normal_(weight, 0.0, args.init_std) - -def init_bias(bias): - nn.init.constant_(bias, 0.0) - -def weights_init(m): - classname = m.__class__.__name__ - if classname.find('Linear') != -1: - if hasattr(m, 'weight') and m.weight is not None: - init_weight(m.weight) - if hasattr(m, 'bias') and m.bias is not None: - init_bias(m.bias) - elif classname.find('AdaptiveEmbedding') != -1: - if hasattr(m, 'emb_projs'): - for i in range(len(m.emb_projs)): - if m.emb_projs[i] is not None: - nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std) - elif classname.find('Embedding') != -1: - if hasattr(m, 'weight'): - init_weight(m.weight) - elif classname.find('ProjectedAdaptiveLogSoftmax') != -1: - if hasattr(m, 'cluster_weight') and m.cluster_weight is not None: - init_weight(m.cluster_weight) - if hasattr(m, 'cluster_bias') and m.cluster_bias is not None: - init_bias(m.cluster_bias) - if hasattr(m, 'out_projs'): - for i in range(len(m.out_projs)): - if m.out_projs[i] is not None: - nn.init.normal_(m.out_projs[i], 0.0, 
args.proj_init_std) - elif classname.find('LayerNorm') != -1: - if hasattr(m, 'weight'): - nn.init.normal_(m.weight, 1.0, args.init_std) - if hasattr(m, 'bias') and m.bias is not None: - init_bias(m.bias) - elif classname.find('TransformerLM') != -1: - if hasattr(m, 'r_emb'): - init_weight(m.r_emb) - if hasattr(m, 'r_w_bias'): - init_weight(m.r_w_bias) - if hasattr(m, 'r_r_bias'): - init_weight(m.r_r_bias) - if hasattr(m, 'r_bias'): - init_bias(m.r_bias) - -def update_dropout(m): - classname = m.__class__.__name__ - if classname.find('Dropout') != -1: - if hasattr(m, 'p'): - m.p = args.dropout - -def update_dropatt(m): - if hasattr(m, 'dropatt'): - m.dropatt.p = args.dropatt - -if args.restart: - with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f: - model = torch.load(f) - if not args.fp16: - model = model.float() - model.apply(update_dropout) - model.apply(update_dropatt) -else: - model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model, - args.d_head, args.d_inner, args.dropout, args.dropatt, - tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val, - tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len, - ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs, - same_length=args.same_length, attn_type=args.attn_type, - clamp_len=args.clamp_len, sample_softmax=args.sample_softmax) - model.apply(weights_init) - model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing -args.n_all_param = sum([p.nelement() for p in model.parameters()]) -args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()]) - -if args.fp16: - model = model.half() - -if args.multi_gpu: - model = model.to(device) - if args.gpu0_bsz >= 0: - para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk, - model, dim=1).to(device) - else: - para_model = nn.DataParallel(model, dim=1).to(device) -else: - para_model = model.to(device) - -#### optimizer -if args.optim.lower() == 'sgd': - if args.sample_softmax > 0: - dense_params, sparse_params = [], [] - for param in model.parameters(): - if param.size() == model.word_emb.weight.size(): - sparse_params.append(param) - else: - dense_params.append(param) - optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2) - optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom) - else: - optimizer = optim.SGD(model.parameters(), lr=args.lr, - momentum=args.mom) -elif args.optim.lower() == 'adam': - if args.sample_softmax > 0: - dense_params, sparse_params = [], [] - for param in model.parameters(): - if param.size() == model.word_emb.weight.size(): - sparse_params.append(param) - else: - dense_params.append(param) - optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr) - optimizer = optim.Adam(dense_params, lr=args.lr) - else: - optimizer = optim.Adam(model.parameters(), lr=args.lr) -elif args.optim.lower() == 'adagrad': - optimizer = optim.Adagrad(model.parameters(), lr=args.lr) - -#### scheduler -if args.scheduler == 'cosine': - # here we do not set eta_min to lr_min to be backward compatible - # because in previous versions eta_min is default to 0 - # rather than the default value of lr_min 1e-6 - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, - args.max_step, eta_min=args.eta_min) # should use eta_min arg - if args.sample_softmax > 0: - scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse, - args.max_step, eta_min=args.eta_min) # should use eta_min arg -elif args.scheduler == 'inv_sqrt': 
- # originally used for Transformer (in Attention is all you need) - def lr_lambda(step): - # return a multiplier instead of a learning rate - if step == 0 and args.warmup_step == 0: - return 1. - else: - return 1. / (step ** 0.5) if step > args.warmup_step \ - else step / (args.warmup_step ** 1.5) - scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) -elif args.scheduler == 'dev_perf': - scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, - factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) - if args.sample_softmax > 0: - scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse, - factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min) -elif args.scheduler == 'constant': - pass - -if args.cuda and args.fp16: - # If args.dynamic_loss_scale is False, static_loss_scale will be used. - # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale. - optimizer = FP16_Optimizer(optimizer, - static_loss_scale = args.static_loss_scale, - dynamic_loss_scale = args.dynamic_loss_scale, - dynamic_loss_args = {'init_scale': 2 ** 16}) - -if args.restart: - if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')): - with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f: - opt_state_dict = torch.load(f) - optimizer.load_state_dict(opt_state_dict) - else: - print('Optimizer was not saved. Start from scratch.') - -logging('=' * 100) -for k, v in args.__dict__.items(): - logging(' - {} : {}'.format(k, v)) -logging('=' * 100) -logging('#params = {}'.format(args.n_all_param)) -logging('#non emb params = {}'.format(args.n_nonemb_param)) - -############################################################################### -# Training code -############################################################################### - -def evaluate(eval_iter): - # Turn on evaluation mode which disables dropout. - model.eval() - - # If the model does not use memory at all, make the ext_len longer. - # Otherwise, make the mem_len longer and keep the ext_len the same. - if args.mem_len == 0: - model.reset_length(args.eval_tgt_len, - args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len) - else: - model.reset_length(args.eval_tgt_len, - args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len) - - # Evaluation - total_len, total_loss = 0, 0. - with torch.no_grad(): - mems = tuple() - for i, (data, target, seq_len) in enumerate(eval_iter): - if args.max_eval_steps > 0 and i >= args.max_eval_steps: - break - ret = model(data, target, *mems) - loss, mems = ret[0], ret[1:] - loss = loss.mean() - total_loss += seq_len * loss.float().item() - total_len += seq_len - - # Switch back to the training mode - model.reset_length(args.tgt_len, args.ext_len, args.mem_len) - model.train() - - return total_loss / total_len - - -def train(): - # Turn on training mode which enables dropout. 
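The `inv_sqrt` branch above is the Noam-style schedule: the multiplier ramps linearly for `warmup_step` steps and then decays as `1/sqrt(step)`, with the two branches meeting exactly at `step == warmup_step`. A minimal sketch of the same lambda, shown standalone:

```python
def inv_sqrt_multiplier(step, warmup_step):
    """LR multiplier for torch.optim.lr_scheduler.LambdaLR: linear ramp during
    warmup, then 1/sqrt(step) decay; the branches meet at step == warmup_step."""
    if step == 0 and warmup_step == 0:
        return 1.0
    if step > warmup_step:
        return step ** -0.5
    return step / warmup_step ** 1.5

# Peak multiplier is warmup_step ** -0.5, reached at the end of warmup.
for s in (0, 1000, 4000, 16000, 64000):
    print(s, round(inv_sqrt_multiplier(s, warmup_step=4000), 6))
```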
- global train_step, train_loss, best_val_loss, eval_start_time, log_start_time - model.train() - if args.batch_chunk > 1: - mems = [tuple() for _ in range(args.batch_chunk)] - else: - mems = tuple() - train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter - for batch, (data, target, seq_len) in enumerate(train_iter): - model.zero_grad() - if args.batch_chunk > 1: - data_chunks = torch.chunk(data, args.batch_chunk, 1) - target_chunks = torch.chunk(target, args.batch_chunk, 1) - for i in range(args.batch_chunk): - data_i = data_chunks[i].contiguous() - target_i = target_chunks[i].contiguous() - ret = para_model(data_i, target_i, *mems[i]) - loss, mems[i] = ret[0], ret[1:] - loss = loss.float().mean().type_as(loss) / args.batch_chunk - if args.fp16: - optimizer.backward(loss) - else: - loss.backward() - train_loss += loss.float().item() - else: - ret = para_model(data, target, *mems) - loss, mems = ret[0], ret[1:] - loss = loss.float().mean().type_as(loss) - if args.fp16: - optimizer.backward(loss) - else: - loss.backward() - train_loss += loss.float().item() - - if args.fp16: - optimizer.clip_master_grads(args.clip) - else: - torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) - - optimizer.step() - if args.sample_softmax > 0: - optimizer_sparse.step() - - # step-wise learning rate annealing - train_step += 1 - if args.scheduler in ['cosine', 'constant', 'dev_perf']: - # linear warmup stage - if train_step < args.warmup_step: - curr_lr = args.lr * train_step / args.warmup_step - optimizer.param_groups[0]['lr'] = curr_lr - if args.sample_softmax > 0: - optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2 - else: - if args.scheduler == 'cosine': - scheduler.step(train_step) - if args.sample_softmax > 0: - scheduler_sparse.step(train_step) - elif args.scheduler == 'inv_sqrt': - scheduler.step(train_step) - - if train_step % args.log_interval == 0: - cur_loss = train_loss / args.log_interval - elapsed = time.time() - log_start_time - log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \ - '| ms/batch {:5.2f} | loss {:5.2f}'.format( - epoch, train_step, batch+1, optimizer.param_groups[0]['lr'], - elapsed * 1000 / args.log_interval, cur_loss) - if args.dataset in ['enwik8', 'text8']: - log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2)) - else: - log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss)) - logging(log_str) - train_loss = 0 - log_start_time = time.time() - - if train_step % args.eval_interval == 0: - val_loss = evaluate(va_iter) - logging('-' * 100) - log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \ - '| valid loss {:5.2f}'.format( - train_step // args.eval_interval, train_step, - (time.time() - eval_start_time), val_loss) - if args.dataset in ['enwik8', 'text8']: - log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2)) - else: - log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss)) - logging(log_str) - logging('-' * 100) - # Save the model if the validation loss is the best we've seen so far. 
- if not best_val_loss or val_loss < best_val_loss: - if not args.debug: - with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f: - torch.save(model, f) - with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f: - torch.save(optimizer.state_dict(), f) - best_val_loss = val_loss - - # dev-performance based learning rate annealing - if args.scheduler == 'dev_perf': - scheduler.step(val_loss) - if args.sample_softmax > 0: - scheduler_sparse.step(val_loss) - - eval_start_time = time.time() - - if train_step == args.max_step: - break - -# Loop over epochs. -train_step = 0 -train_loss = 0 -best_val_loss = None - -log_start_time = time.time() -eval_start_time = time.time() - -# At any point you can hit Ctrl + C to break out of training early. -try: - for epoch in itertools.count(start=1): - train() - if train_step == args.max_step: - logging('-' * 100) - logging('End of training') - break -except KeyboardInterrupt: - logging('-' * 100) - logging('Exiting from training early') - -# Load the best saved model. -with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f: - model = torch.load(f) -para_model = model.to(device) - -# Run on test data. -test_loss = evaluate(te_iter) -logging('=' * 100) -if args.dataset in ['enwik8', 'text8']: - logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format( - test_loss, test_loss / math.log(2))) -else: - logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format( - test_loss, math.exp(test_loss))) -logging('=' * 100) diff --git a/transformer-xl/pytorch/utils/adaptive_softmax.py b/transformer-xl/pytorch/utils/adaptive_softmax.py deleted file mode 100644 index f22da23..0000000 --- a/transformer-xl/pytorch/utils/adaptive_softmax.py +++ /dev/null @@ -1,90 +0,0 @@ -from collections import defaultdict - -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -class AdaptiveLogSoftmax(nn.Module): - def __init__(self, in_features, n_classes, cutoffs, keep_order=False): - super(AdaptiveLogSoftmax, self).__init__() - - cutoffs = list(cutoffs) - - if (cutoffs != sorted(cutoffs)) \ - or (min(cutoffs) <= 0) \ - or (max(cutoffs) >= (n_classes - 1)) \ - or (len(set(cutoffs)) != len(cutoffs)) \ - or any([int(c) != c for c in cutoffs]): - - raise ValueError("cutoffs should be a sequence of unique, positive " - "integers sorted in an increasing order, where " - "each value is between 1 and n_classes-1") - - self.in_features = in_features - self.n_classes = n_classes - self.cutoffs = cutoffs + [n_classes] - - self.shortlist_size = self.cutoffs[0] - self.n_clusters = len(self.cutoffs) - 1 - self.head_size = self.shortlist_size + self.n_clusters - - self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features)) - self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) - - self.keep_order = keep_order - - - def forward(self, hidden, target, weight, bias, keep_order=False): - if hidden.size(0) != target.size(0): - raise RuntimeError('Input and target should have the same size ' - 'in the batch dimension.') - - head_weight = torch.cat( - [weight[:self.shortlist_size], self.cluster_weight], dim=0) - head_bias = torch.cat( - [bias[:self.shortlist_size], self.cluster_bias], dim=0) - - head_logit = F.linear(hidden, head_weight, bias=head_bias) - head_logprob = F.log_softmax(head_logit, dim=1) - - nll = torch.zeros_like(target, - dtype=hidden.dtype, device=hidden.DEVICE) - - offset = 0 - cutoff_values = [0] + self.cutoffs - for i in range(len(cutoff_values) - 1): - l_idx, 
h_idx = cutoff_values[i], cutoff_values[i + 1] - - mask_i = (target >= l_idx) & (target < h_idx) - indices_i = mask_i.nonzero().squeeze() - - if indices_i.numel() == 0: - continue - - target_i = target.index_select(0, indices_i) - l_idx - head_logprob_i = head_logprob.index_select(0, indices_i) - - if i == 0: - logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) - else: - weight_i = weight[l_idx:h_idx] - bias_i = bias[l_idx:h_idx] - - hidden_i = hidden.index_select(0, indices_i) - - tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i) - tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) - - logprob_i = head_logprob_i[:, -i] \ - + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) - - if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: - nll.index_copy_(0, indices_i, -logprob_i) - else: - nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) - - offset += logprob_i.size(0) - - return nll diff --git a/transformer-xl/pytorch/utils/data_parallel.py b/transformer-xl/pytorch/utils/data_parallel.py deleted file mode 100644 index d7e1811..0000000 --- a/transformer-xl/pytorch/utils/data_parallel.py +++ /dev/null @@ -1,91 +0,0 @@ - -from torch.nn.parallel import DataParallel -import torch -from torch.nn.parallel._functions import Scatter -from torch.nn.parallel.parallel_apply import parallel_apply - -def scatter(inputs, target_gpus, chunk_sizes, dim=0): - r""" - Slices tensors into approximately equal chunks and - distributes them across given GPUs. Duplicates - references to objects that are not tensors. - """ - def scatter_map(obj): - if isinstance(obj, torch.Tensor): - try: - return Scatter.apply(target_gpus, chunk_sizes, dim, obj) - except: - print('obj', obj.size()) - print('dim', dim) - print('chunk_sizes', chunk_sizes) - quit() - if isinstance(obj, tuple) and len(obj) > 0: - return list(zip(*map(scatter_map, obj))) - if isinstance(obj, list) and len(obj) > 0: - return list(map(list, zip(*map(scatter_map, obj)))) - if isinstance(obj, dict) and len(obj) > 0: - return list(map(type(obj), zip(*map(scatter_map, obj.items())))) - return [obj for targets in target_gpus] - - # After scatter_map is called, a scatter_map cell will exist. This cell - # has a reference to the actual function scatter_map, which has references - # to a closure that has a reference to the scatter_map cell (because the - # fn is recursive). 
To avoid this reference cycle, we set the function to - # None, clearing the cell - try: - return scatter_map(inputs) - finally: - scatter_map = None - -def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0): - r"""Scatter with support for kwargs dictionary""" - inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else [] - kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else [] - if len(inputs) < len(kwargs): - inputs.extend([() for _ in range(len(kwargs) - len(inputs))]) - elif len(kwargs) < len(inputs): - kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))]) - inputs = tuple(inputs) - kwargs = tuple(kwargs) - return inputs, kwargs - -class BalancedDataParallel(DataParallel): - def __init__(self, gpu0_bsz, *args, **kwargs): - self.gpu0_bsz = gpu0_bsz - super().__init__(*args, **kwargs) - - def forward(self, *inputs, **kwargs): - if not self.device_ids: - return self.module(*inputs, **kwargs) - if self.gpu0_bsz == 0: - device_ids = self.device_ids[1:] - else: - device_ids = self.device_ids - inputs, kwargs = self.scatter(inputs, kwargs, device_ids) - if len(self.device_ids) == 1: - return self.module(*inputs[0], **kwargs[0]) - replicas = self.replicate(self.module, self.device_ids) - if self.gpu0_bsz == 0: - replicas = replicas[1:] - outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs) - return self.gather(outputs, self.output_device) - - def parallel_apply(self, replicas, device_ids, inputs, kwargs): - return parallel_apply(replicas, inputs, kwargs, device_ids) - - def scatter(self, inputs, kwargs, device_ids): - bsz = inputs[0].size(self.dim) - num_dev = len(self.device_ids) - gpu0_bsz = self.gpu0_bsz - bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1) - if gpu0_bsz < bsz_unit: - chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1) - delta = bsz - sum(chunk_sizes) - for i in range(delta): - chunk_sizes[i + 1] += 1 - if gpu0_bsz == 0: - chunk_sizes = chunk_sizes[1:] - else: - return super().scatter(inputs, kwargs, device_ids) - return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim) - diff --git a/transformer-xl/pytorch/utils/exp_utils.py b/transformer-xl/pytorch/utils/exp_utils.py deleted file mode 100644 index e44f7c2..0000000 --- a/transformer-xl/pytorch/utils/exp_utils.py +++ /dev/null @@ -1,40 +0,0 @@ -import functools -import os, shutil - -import numpy as np - -import torch - - -def logging(s, log_path, print_=True, log_=True): - if print_: - print(s) - if log_: - with open(log_path, 'a+') as f_log: - f_log.write(s + '\n') - -def get_logger(log_path, **kwargs): - return functools.partial(logging, log_path=log_path, **kwargs) - -def create_exp_dir(dir_path, scripts_to_save=None, debug=False): - if debug: - print('Debug Mode : no experiment dir created') - return functools.partial(logging, log_path=None, log_=False) - - if not os.path.exists(dir_path): - os.makedirs(dir_path) - - print('Experiment dir : {}'.format(dir_path)) - if scripts_to_save is not None: - script_path = os.path.join(dir_path, 'scripts') - if not os.path.exists(script_path): - os.makedirs(script_path) - for script in scripts_to_save: - dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script)) - shutil.copyfile(script, dst_file) - - return get_logger(log_path=os.path.join(dir_path, 'log.txt')) - -def save_checkpoint(model, optimizer, path, epoch): - torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch))) - torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch))) diff --git 
a/transformer-xl/pytorch/utils/log_uniform_sampler.py b/transformer-xl/pytorch/utils/log_uniform_sampler.py deleted file mode 100644 index 857ad52..0000000 --- a/transformer-xl/pytorch/utils/log_uniform_sampler.py +++ /dev/null @@ -1,147 +0,0 @@ -import torch -from torch import nn -import numpy as np - -class LogUniformSampler(object): - def __init__(self, range_max, n_sample): - """ - Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py - `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` - - expected count can be approximated by 1 - (1 - p)^n - and we use a numerically stable version -expm1(num_tries * log1p(-p)) - - Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run - """ - with torch.no_grad(): - self.range_max = range_max - log_indices = torch.arange(1., range_max+2., 1.).log_() - self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] - # print('P', self.dist.numpy().tolist()[-30:]) - - self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float() - - self.n_sample = n_sample - - def sample(self, labels): - """ - labels: [b1, b2] - Return - true_log_probs: [b1, b2] - samp_log_probs: [n_sample] - neg_samples: [n_sample] - """ - - # neg_samples = torch.empty(0).long() - n_sample = self.n_sample - n_tries = 2 * n_sample - - with torch.no_grad(): - neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique() - device = labels.DEVICE - neg_samples = neg_samples.to(device) - true_log_probs = self.log_q[labels].to(device) - samp_log_probs = self.log_q[neg_samples].to(device) - return true_log_probs, samp_log_probs, neg_samples - -def sample_logits(embedding, bias, labels, inputs, sampler): - """ - embedding: an nn.Embedding layer - bias: [n_vocab] - labels: [b1, b2] - inputs: [b1, b2, n_emb] - sampler: you may use a LogUniformSampler - Return - logits: [b1, b2, 1 + n_sample] - """ - true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels) - n_sample = neg_samples.size(0) - b1, b2 = labels.size(0), labels.size(1) - all_ids = torch.cat([labels.view(-1), neg_samples]) - all_w = embedding(all_ids) - true_w = all_w[: -n_sample].view(b1, b2, -1) - sample_w = all_w[- n_sample:].view(n_sample, -1) - - all_b = bias[all_ids] - true_b = all_b[: -n_sample].view(b1, b2) - sample_b = all_b[- n_sample:] - - hit = (labels[:, :, None] == neg_samples).detach() - - true_logits = torch.einsum('ijk,ijk->ij', - [true_w, inputs]) + true_b - true_log_probs - sample_logits = torch.einsum('lk,ijk->ijl', - [sample_w, inputs]) + sample_b - samp_log_probs - sample_logits.masked_fill_(hit, -1e30) - logits = torch.cat([true_logits[:, :, None], sample_logits], -1) - - return logits - - -# class LogUniformSampler(object): -# def __init__(self, range_max, unique=False): -# """ -# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py -# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)` -# """ -# self.range_max = range_max -# log_indices = torch.arange(1., range_max+2., 1.).log_() -# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1] - -# self.unique = unique - -# if self.unique: -# self.exclude_mask = torch.ByteTensor(range_max).fill_(0) - -# def sample(self, n_sample, labels): -# pos_sample, new_labels = labels.unique(return_inverse=True) -# n_pos_sample = pos_sample.size(0) -# n_neg_sample = n_sample - n_pos_sample - -# if self.unique: -# 
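`LogUniformSampler` above builds its proposal from the distribution quoted in its docstring, `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`, and corrects sampled logits with the expected count `1 - (1 - p)^n`. A short standalone sketch of just those two pieces, with illustrative helper names:

```python
import torch

def log_uniform_dist(range_max):
    # P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)
    log_indices = torch.arange(1., range_max + 2.).log()
    return (log_indices[1:] - log_indices[:-1]) / log_indices[-1]

def expected_count(probs, num_tries):
    # Expected number of times each class appears among `num_tries` draws with
    # replacement: 1 - (1 - p)^n, computed in a numerically stable form.
    return -torch.expm1(num_tries * torch.log1p(-probs))

dist = log_uniform_dist(10000)
print(float(dist.sum()))             # ~1.0: a proper distribution
print(expected_count(dist[:3], 10))  # frequent (low-index) classes dominate
```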
self.exclude_mask.index_fill_(0, pos_sample, 1) -# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0) -# self.exclude_mask.index_fill_(0, pos_sample, 0) -# else: -# sample_dist = self.dist - -# neg_sample = torch.multinomial(sample_dist, n_neg_sample) - -# sample = torch.cat([pos_sample, neg_sample]) -# sample_prob = self.dist[sample] - -# return new_labels, sample, sample_prob - - -if __name__ == '__main__': - S, B = 3, 4 - n_vocab = 10000 - n_sample = 5 - H = 32 - - labels = torch.LongTensor(S, B).random_(0, n_vocab) - - # sampler = LogUniformSampler(n_vocab, unique=False) - # new_labels, sample, sample_prob = sampler.sample(n_sample, labels) - - sampler = LogUniformSampler(n_vocab, unique=True) - # true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels) - - # print('true_probs', true_probs.numpy().tolist()) - # print('samp_probs', samp_probs.numpy().tolist()) - # print('neg_samples', neg_samples.numpy().tolist()) - - # print('sum', torch.sum(sampler.dist).item()) - - # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item() - - embedding = nn.Embedding(n_vocab, H) - bias = torch.zeros(n_vocab) - inputs = torch.Tensor(S, B, H).normal_() - - logits, out_labels = sample_logits(embedding, bias, labels, inputs, sampler, n_sample) - print('logits', logits.detach().numpy().tolist()) - print('logits shape', logits.size()) - print('out_labels', out_labels.detach().numpy().tolist()) - print('out_labels shape', out_labels.size()) - diff --git a/transformer-xl/pytorch/utils/proj_adaptive_softmax.py b/transformer-xl/pytorch/utils/proj_adaptive_softmax.py deleted file mode 100644 index c5a0f84..0000000 --- a/transformer-xl/pytorch/utils/proj_adaptive_softmax.py +++ /dev/null @@ -1,151 +0,0 @@ -from collections import defaultdict - -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -CUDA_MAJOR = int(torch.version.cuda.split('.')[0]) -CUDA_MINOR = int(torch.version.cuda.split('.')[1]) - -class ProjectedAdaptiveLogSoftmax(nn.Module): - def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, - keep_order=False): - super(ProjectedAdaptiveLogSoftmax, self).__init__() - - self.n_token = n_token - self.d_embed = d_embed - self.d_proj = d_proj - - self.cutoffs = cutoffs + [n_token] - self.cutoff_ends = [0] + self.cutoffs - self.div_val = div_val - - self.shortlist_size = self.cutoffs[0] - self.n_clusters = len(self.cutoffs) - 1 - self.head_size = self.shortlist_size + self.n_clusters - - if self.n_clusters > 0: - self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) - self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) - - self.out_layers = nn.ModuleList() - self.out_projs = nn.ParameterList() - - if div_val == 1: - for i in range(len(self.cutoffs)): - if d_proj != d_embed: - self.out_projs.append( - nn.Parameter(torch.Tensor(d_proj, d_embed)) - ) - else: - self.out_projs.append(None) - - self.out_layers.append(nn.Linear(d_embed, n_token)) - else: - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] - d_emb_i = d_embed // (div_val ** i) - - self.out_projs.append( - nn.Parameter(torch.Tensor(d_proj, d_emb_i)) - ) - - self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx)) - - self.keep_order = keep_order - - def _compute_logit(self, hidden, weight, bias, proj): - if proj is None: - logit = F.linear(hidden, weight, bias=bias) - else: - # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1: - proj_hid = F.linear(hidden, 
proj.t().contiguous()) - logit = F.linear(proj_hid, weight, bias=bias) - # else: - # logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) - # if bias is not None: - # logit = logit + bias - - return logit - - def forward(self, hidden, target, keep_order=False): - ''' - hidden :: [len*bsz x d_proj] - target :: [len*bsz] - ''' - - if hidden.size(0) != target.size(0): - raise RuntimeError('Input and target should have the same size ' - 'in the batch dimension.') - - if self.n_clusters == 0: - logit = self._compute_logit(hidden, self.out_layers[0].weight, - self.out_layers[0].bias, self.out_projs[0]) - nll = -F.log_softmax(logit, dim=-1) \ - .gather(1, target.unsqueeze(1)).squeeze(1) - else: - # construct weights and biases - weights, biases = [], [] - for i in range(len(self.cutoffs)): - if self.div_val == 1: - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - weight_i = self.out_layers[0].weight[l_idx:r_idx] - bias_i = self.out_layers[0].bias[l_idx:r_idx] - else: - weight_i = self.out_layers[i].weight - bias_i = self.out_layers[i].bias - - if i == 0: - weight_i = torch.cat( - [weight_i, self.cluster_weight], dim=0) - bias_i = torch.cat( - [bias_i, self.cluster_bias], dim=0) - - weights.append(weight_i) - biases.append(bias_i) - - head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0] - - head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) - head_logprob = F.log_softmax(head_logit, dim=1) - - nll = torch.zeros_like(target, - dtype=hidden.dtype, device=hidden.DEVICE) - - offset = 0 - cutoff_values = [0] + self.cutoffs - for i in range(len(cutoff_values) - 1): - l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] - - mask_i = (target >= l_idx) & (target < r_idx) - indices_i = mask_i.nonzero().squeeze() - - if indices_i.numel() == 0: - continue - - target_i = target.index_select(0, indices_i) - l_idx - head_logprob_i = head_logprob.index_select(0, indices_i) - - if i == 0: - logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1) - else: - weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i] - - hidden_i = hidden.index_select(0, indices_i) - - tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) - tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) - - logprob_i = head_logprob_i[:, -i] \ - + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1) - - if (hasattr(self, 'keep_order') and self.keep_order) or keep_order: - nll.index_copy_(0, indices_i, -logprob_i) - else: - nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) - - offset += logprob_i.size(0) - - return nll diff --git a/transformer-xl/pytorch/utils/vocabulary.py b/transformer-xl/pytorch/utils/vocabulary.py deleted file mode 100644 index b6b8249..0000000 --- a/transformer-xl/pytorch/utils/vocabulary.py +++ /dev/null @@ -1,163 +0,0 @@ -import os -from collections import Counter, OrderedDict - -import torch - -class Vocab(object): - def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, - delimiter=None, vocab_file=None): - self.counter = Counter() - self.special = special - self.min_freq = min_freq - self.max_size = max_size - self.lower_case = lower_case - self.delimiter = delimiter - self.vocab_file = vocab_file - - def tokenize(self, line, add_eos=False, add_double_eos=False): - line = line.strip() - # convert to lower case - if self.lower_case: - line = line.lower() - - # empty delimiter '' will evaluate False - if self.delimiter == '': - symbols = line - else: - symbols = line.split(self.delimiter) 
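`ProjectedAdaptiveLogSoftmax` above factorizes the output distribution as `log P(w | h) = log P(cluster(w) | h) + log P(w | cluster(w), h)`, with shortlist words scored directly by the head. A compact illustrative sketch of that factorization, using hypothetical shapes and names:

```python
import torch
import torch.nn.functional as F

def two_level_logprob(head_logits, tail_logits, cluster_id, word_in_cluster):
    """log P(word | h) = log P(cluster | h) + log P(word | cluster, h).

    head_logits: [bsz, shortlist + n_clusters]  (shortlist words plus one logit per tail cluster)
    tail_logits: [bsz, cluster_size]            (words inside the chosen tail cluster)
    """
    head_logprob = F.log_softmax(head_logits, dim=-1)
    tail_logprob = F.log_softmax(tail_logits, dim=-1)
    return head_logprob[:, cluster_id] + tail_logprob[:, word_in_cluster]

# Toy example: shortlist of 5 words plus 2 tail clusters; score a word from the
# first tail cluster (its cluster logit sits at head column index 5).
head = torch.randn(3, 5 + 2)
tail = torch.randn(3, 100)
print(two_level_logprob(head, tail, cluster_id=5, word_in_cluster=42))
```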
- - if add_double_eos: # lm1b - return [''] + symbols + [''] - elif add_eos: - return symbols + [''] - else: - return symbols - - def count_file(self, path, verbose=False, add_eos=False): - if verbose: print('counting file {} ...'.format(path)) - assert os.path.exists(path) - - sents = [] - with open(path, 'r', encoding='utf-8') as f: - for idx, line in enumerate(f): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - symbols = self.tokenize(line, add_eos=add_eos) - self.counter.update(symbols) - sents.append(symbols) - - return sents - - def count_sents(self, sents, verbose=False): - """ - sents : a list of sentences, each a list of tokenized symbols - """ - if verbose: print('counting {} sents ...'.format(len(sents))) - for idx, symbols in enumerate(sents): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - self.counter.update(symbols) - - def _build_from_file(self, vocab_file): - self.idx2sym = [] - self.sym2idx = OrderedDict() - - with open(vocab_file, 'r', encoding='utf-8') as f: - for line in f: - symb = line.strip().split()[0] - self.add_symbol(symb) - self.unk_idx = self.sym2idx[''] - - def build_vocab(self): - if self.vocab_file: - print('building vocab from {}'.format(self.vocab_file)) - self._build_from_file(self.vocab_file) - print('final vocab size {}'.format(len(self))) - else: - print('building vocab with min_freq={}, max_size={}'.format( - self.min_freq, self.max_size)) - self.idx2sym = [] - self.sym2idx = OrderedDict() - - for sym in self.special: - self.add_special(sym) - - for sym, cnt in self.counter.most_common(self.max_size): - if cnt < self.min_freq: break - self.add_symbol(sym) - - print('final vocab size {} from {} unique tokens'.format( - len(self), len(self.counter))) - - def encode_file(self, path, ordered=False, verbose=False, add_eos=True, - add_double_eos=False): - if verbose: print('encoding file {} ...'.format(path)) - assert os.path.exists(path) - encoded = [] - with open(path, 'r', encoding='utf-8') as f: - for idx, line in enumerate(f): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - symbols = self.tokenize(line, add_eos=add_eos, - add_double_eos=add_double_eos) - encoded.append(self.convert_to_tensor(symbols)) - - if ordered: - encoded = torch.cat(encoded) - - return encoded - - def encode_sents(self, sents, ordered=False, verbose=False): - if verbose: print('encoding {} sents ...'.format(len(sents))) - encoded = [] - for idx, symbols in enumerate(sents): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - encoded.append(self.convert_to_tensor(symbols)) - - if ordered: - encoded = torch.cat(encoded) - - return encoded - - def add_special(self, sym): - if sym not in self.sym2idx: - self.idx2sym.append(sym) - self.sym2idx[sym] = len(self.idx2sym) - 1 - setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) - - def add_symbol(self, sym): - if sym not in self.sym2idx: - self.idx2sym.append(sym) - self.sym2idx[sym] = len(self.idx2sym) - 1 - - def get_sym(self, idx): - assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) - return self.idx2sym[idx] - - def get_idx(self, sym): - if sym in self.sym2idx: - return self.sym2idx[sym] - else: - # print('encounter unk {}'.format(sym)) - assert '' not in sym - assert hasattr(self, 'unk_idx') - return self.sym2idx.get(sym, self.unk_idx) - - def get_symbols(self, indices): - return [self.get_sym(idx) for idx in indices] - - def get_indices(self, symbols): - return 
[self.get_idx(sym) for sym in symbols] - - def convert_to_tensor(self, symbols): - return torch.LongTensor(self.get_indices(symbols)) - - def convert_to_sent(self, indices, exclude=None): - if exclude is None: - return ' '.join([self.get_sym(idx) for idx in indices]) - else: - return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) - - def __len__(self): - return len(self.idx2sym) diff --git a/transformer-xl/tf/README.md b/transformer-xl/tf/README.md deleted file mode 100644 index 1cd82a0..0000000 --- a/transformer-xl/tf/README.md +++ /dev/null @@ -1,131 +0,0 @@ - -## Introduction - -This directory contains our TF implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our gpu codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts: -- `*large_tpu.sh` are for the SoTA setting on TPUs. These are exactly the commands we used to obtained our best results. -- `*base_gpu.sh` are for the base models which can be run on a few GPUs. - - -## Prerequisite - -- Python 2.7 -- Tensorflow [1.12.0](https://github.com/tensorflow/tensorflow/releases/tag/v1.12.0) - - - -## Obtain and evaluate pretrained SoTA models - -#### 1. Download preprocessed data (vocab) & pretrained models - -(a) Set your own `DATA_ROOT` in `sota/download.sh` (default to `./`), which will be the root diretory of downloaded model. - -(b) Then, download the model & data by `bash sota/download.sh`. After downloading, the expected directory structure is as follows - -```markdown -pretrained_xl - tf_enwik8/ - data/ - cache.pkl - corpus-info.json - model/ - checkpoint - model.ckpt* - tf_wt103/ - ... - ... -``` - -**Note**: we include preprocessed data in the download files to make sure the **same vocabulary** is used. Please see the code `tf/data_utils.py` to understand the data structure. - - - -#### 2. Run evaluation scripts to replicate SoTA results on GPUs - -- **enwik8**: modify the script `sota/enwik8.sh` accordingly (see below) - - set `DATA_ROOT` to the same folder used in the download step (default to `./`) - - set `TEST_NUM_CORE ` (number of GPUs to use): we recommend 2 GPUs => about 60 mins - - run the script: `bash sota/enwik8.sh` - -- **lm1b**: modify the script `sota/lm1b.sh` accordingly (see below) - - set `DATA_ROOT` to the same folder used in the download step (default to `./`) - - set `TEST_NUM_CORE ` (number of GPUs to use): we recommend 1 GPUs => less than 5 mins - - run the script: `bash sota/lm1b.sh` - -- **wt103**: modify the script `sota/wt103.sh` accordingly (see below) - - set `DATA_ROOT` to the same folder used in the download step (default to `./`) - - set `TEST_NUM_CORE ` (number of GPUs to use): we recommend 1 GPUs => less than 5 mins - - run the script: `bash sota/wt103.sh` - -- **text8**: modify the script `sota/text8.sh` accordingly (see below) - - set `DATA_ROOT` to the same folder used in the download step (default to `./`) - - set `TEST_NUM_CORE ` (number of GPUs to use): we recommend 2 GPUs => about 60 mins - - run the script: `bash sota/text8.sh` - - -#### 3. Resources Needed for SoTA Model Training - -We used 32, 32, 64, and 512 TPU cores for training our best models on enwik8, text8, wt103, and lm1b respectively. The training time for each model ranges from 2 to 5 days. - - - -## Train "Transformer-XL" from scratch with GPUs or TPUs - -### 1. Download raw data - -`bash getdata.sh` - - - -### 2. 
Preprocess, training and evaluation - -For `dataset` in `[enwik8, lm1b, wt103, text8]`: - -- check out `scripts/dataset_base_gpu.sh` for GPU training and evaluation -- check out `scripts/dataset_large_tpu.sh` for TPU training and evaluation - - - -#### (1) Preprocess raw data and create tfrecords - -**NOTE**: The preprocessing for GPU and TPU are different. So, you have to run them separately. - -GPU: - -- create training and validation data: `bash scripts/dataset_bas_gpu.sh train_data` -- create test data: `bash scripts/dataset_base_gpu.sh test_data` - -TPU: - -- Set the Google storage URL in `scripts/dataset_large_tpu.sh`: - - `GSDATA`: data URL - - `GSEXP`: experiment URL -- create training and validation data: `bash scripts/dataset_large_tpu.sh train_data` -- create test data: `bash scripts/dataset_large_tpu.sh test_data` - - - -#### (2) Run training - -Base models on GPUs: - -- Modify the configurations in `scripts/dataset_base_gpu.sh` according to your needs. -- `bash scripts/dataset_base_gpu.sh train` -- If enough resources are available, increasing the model sizes (e.g., `N_LAYER`, `D_MODEL`, `D_EMBED`, `D_HEAD`, `D_INNER`) so that they are closer to the values defined in `scripts/dataset_large_tpu.sh`. Likewise, when resources are limited, decrease the model sizes. It is recommended to ensure that `D_MODEL == D_EMBED` and `D_MODEL == N_HEAD x D_HEAD`. When the model sizes increase, remember to increase `warmup_steps` accordingly to alleviate optimization difficulties. -- Adjust the `NUM_CORE` parameter to reflect the number of GPUs to use. - -Larger models on TPUs: - -- Modify the configurations in `scripts/dataset_large_tpu.sh` according to your needs. -- `bash scripts/dataset_large_tpu.sh train` - - - -#### (3) Run evaluation - -Base models on GPUs: - -- `bash scripts/dataset_base_gpu.sh eval --eval_ckpt_path PATH_TO_CKPT` - -Larger models on TPUs: - -- `bash scripts/dataset_base_tpu.sh eval --eval_ckpt_path PATH_TO_CKPT` diff --git a/transformer-xl/tf/avg_checkpoints.py b/transformer-xl/tf/avg_checkpoints.py deleted file mode 100644 index ffa71b6..0000000 --- a/transformer-xl/tf/avg_checkpoints.py +++ /dev/null @@ -1,118 +0,0 @@ -# coding=utf-8 -# Copyright 2018 The Tensor2Tensor Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Script to average values of variables in a list of checkpoint files.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import numpy as np -import six -from six.moves import zip # pylint: disable=redefined-builtin -import tensorflow as tf - -flags = tf.flags -FLAGS = flags.FLAGS - -flags.DEFINE_string("checkpoints", "", - "Comma-separated list of checkpoints to average.") -flags.DEFINE_integer("num_last_checkpoints", 0, - "Averages the last N saved checkpoints." 
- " If the checkpoints flag is set, this is ignored.") -flags.DEFINE_string("prefix", "", - "Prefix (e.g., directory) to append to each checkpoint.") -flags.DEFINE_string("output_path", "/tmp/averaged.ckpt", - "Path to output the averaged checkpoint to.") - - -def checkpoint_exists(path): - return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or - tf.gfile.Exists(path + ".index")) - - -def main(_): - tf.logging.set_verbosity(tf.logging.INFO) - if FLAGS.checkpoints: - # Get the checkpoints list from flags and run some basic checks. - checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")] - checkpoints = [c for c in checkpoints if c] - if not checkpoints: - raise ValueError("No checkpoints provided for averaging.") - if FLAGS.prefix: - checkpoints = [FLAGS.prefix + c for c in checkpoints] - else: - assert FLAGS.num_last_checkpoints >= 1, "Must average at least one model" - assert FLAGS.prefix, ("Prefix must be provided when averaging last" - " N checkpoints") - checkpoint_state = tf.train.get_checkpoint_state( - os.path.dirname(FLAGS.prefix)) - # Checkpoints are ordered from oldest to newest. - checkpoints = checkpoint_state.all_model_checkpoint_paths[ - -FLAGS.num_last_checkpoints:] - - checkpoints = [c for c in checkpoints if checkpoint_exists(c)] - if not checkpoints: - if FLAGS.checkpoints: - raise ValueError( - "None of the provided checkpoints exist. %s" % FLAGS.checkpoints) - else: - raise ValueError("Could not find checkpoints at %s" % - os.path.dirname(FLAGS.prefix)) - - # Read variables from all checkpoints and average them. - tf.logging.info("Reading variables and averaging checkpoints:") - for c in checkpoints: - tf.logging.info("%s ", c) - var_list = tf.contrib.framework.list_variables(checkpoints[0]) - var_values, var_dtypes = {}, {} - for (name, shape) in var_list: - if not name.startswith("global_step"): - var_values[name] = np.zeros(shape) - for checkpoint in checkpoints: - reader = tf.contrib.framework.load_checkpoint(checkpoint) - for name in var_values: - tensor = reader.get_tensor(name) - var_dtypes[name] = tensor.dtype - var_values[name] += tensor - tf.logging.info("Read from checkpoint %s", checkpoint) - for name in var_values: # Average. - var_values[name] /= len(checkpoints) - - with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): - tf_vars = [ - tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v]) - for v in var_values - ] - placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars] - assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)] - global_step = tf.Variable( - 0, name="global_step", trainable=False, dtype=tf.int64) - saver = tf.train.Saver(tf.all_variables()) - - # Build a model consisting only of variables, set them to the average values. - with tf.Session() as sess: - sess.run(tf.initialize_all_variables()) - for p, assign_op, (name, value) in zip(placeholders, assign_ops, - six.iteritems(var_values)): - sess.run(assign_op, {p: value}) - # Use the built saver to save the averaged checkpoint. 
- saver.save(sess, FLAGS.output_path, global_step=global_step) - - tf.logging.info("Averaged checkpoints saved in %s", FLAGS.output_path) - - -if __name__ == "__main__": - tf.app.run() diff --git a/transformer-xl/tf/data_utils.py b/transformer-xl/tf/data_utils.py deleted file mode 100644 index ea2e32b..0000000 --- a/transformer-xl/tf/data_utils.py +++ /dev/null @@ -1,586 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math -import os -from functools import partial - -from collections import Counter, OrderedDict -import pickle -import json -import multiprocessing as mp - -import numpy as np - -from absl import flags -import tensorflow as tf -from vocabulary import Vocab - -from tensorflow.gfile import Exists as exists -from tensorflow.gfile import MakeDirs as makedirs -from tensorflow.gfile import Glob as glob - - -def _preprocess(shard, train, vocab, save_dir, cutoffs, bin_sizes, bsz, tgt_len, - num_core_per_host, use_tpu, num_shuffle): - file_names = [] - num_batch = 0 - - path = train[shard] - data_shard = vocab.encode_file(path, ordered=False, add_double_eos=True) - - for shuffle in range(num_shuffle): - basename = "train-{:03d}-{:02d}".format(shard, shuffle) - print("Processing shard {} shuffle {}".format(shard, shuffle)) - - np.random.shuffle(data_shard) - file_name, num_batch_shuffle = create_ordered_tfrecords( - save_dir, basename, np.concatenate(data_shard), bsz, tgt_len, - num_core_per_host, cutoffs, bin_sizes, use_tpu=use_tpu) - file_names.append(file_name) - num_batch += num_batch_shuffle - - return file_names, num_batch - - -class Corpus(object): - def __init__(self, path, dataset, *args, **kwargs): - self.dataset = dataset - self.vocab = Vocab(*args, **kwargs) - - if self.dataset in ["ptb", "wt2", "enwik8", "text8"]: - self.vocab.count_file(os.path.join(path, "train.txt")) - self.vocab.count_file(os.path.join(path, "valid.txt")) - self.vocab.count_file(os.path.join(path, "test.txt")) - elif self.dataset == "wt103": - self.vocab.count_file(os.path.join(path, "train.txt")) - elif self.dataset == "lm1b": - train_path_pattern = os.path.join( - path, "1-billion-word-language-modeling-benchmark-r13output", - "training-monolingual.tokenized.shuffled", "news.en-*") - train_paths = glob(train_path_pattern) - - # the vocab will load from file when build_vocab() is called - # for train_path in sorted(train_paths): - # self.vocab.count_file(train_path, verbose=True) - - self.vocab.build_vocab() - - if self.dataset in ["ptb", "wt2", "wt103"]: - self.train = self.vocab.encode_file( - os.path.join(path, "train.txt"), ordered=True) - self.valid = self.vocab.encode_file( - os.path.join(path, "valid.txt"), ordered=True) - self.test = self.vocab.encode_file( - os.path.join(path, "test.txt"), ordered=True) - elif self.dataset in ["enwik8", "text8"]: - self.train = self.vocab.encode_file( - os.path.join(path, "train.txt"), ordered=True, add_eos=False) - self.valid = self.vocab.encode_file( - os.path.join(path, "valid.txt"), ordered=True, add_eos=False) - self.test = self.vocab.encode_file( - os.path.join(path, "test.txt"), ordered=True, add_eos=False) - elif self.dataset == "lm1b": - self.train = train_paths - valid_path = os.path.join(path, "valid.txt") - test_path = valid_path - self.valid = self.vocab.encode_file( - valid_path, ordered=True, add_double_eos=True) - self.test = self.vocab.encode_file( - test_path, ordered=True, add_double_eos=True) - - if self.dataset == "wt103": - self.cutoffs = [0, 20000, 40000, 
200000] + [len(self.vocab)] - elif self.dataset == "lm1b": - self.cutoffs = [0, 60000, 100000, 640000] + [len(self.vocab)] - else: - self.cutoffs = [] - - - def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len, - num_core_per_host, **kwargs): - FLAGS = kwargs.get('FLAGS') - - file_names = [] - use_tpu = FLAGS.use_tpu and not (split == "test" and num_core_per_host == 1) - - if use_tpu: - record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format( - split, bsz, tgt_len, num_core_per_host) - else: - record_name = "record_info-{}.bsz-{}.tlen-{}.json".format( - split, bsz, tgt_len) - - record_info_path = os.path.join(save_dir, record_name) - - if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]: - data = getattr(self, split) - bin_sizes = get_bin_sizes( - data, bsz // num_core_per_host, tgt_len, self.cutoffs) - file_name, num_batch = create_ordered_tfrecords( - save_dir, split, data, bsz, tgt_len, num_core_per_host, - self.cutoffs, bin_sizes, - num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1, - use_tpu=use_tpu) - file_names.append(file_name) - elif self.dataset == "lm1b": - bin_sizes = get_bin_sizes( - self.valid, bsz // num_core_per_host, tgt_len, self.cutoffs) - if split == "train": - np.random.seed(123456) - num_batch = 0 - - if FLAGS.num_procs > 1: - _preprocess_wrapper = partial(_preprocess, - train=self.train, vocab=self.vocab, save_dir=save_dir, - cutoffs=self.cutoffs, bin_sizes=bin_sizes, bsz=bsz, - tgt_len=tgt_len, num_core_per_host=num_core_per_host, - use_tpu=use_tpu, num_shuffle=FLAGS.num_shuffle) - - pool = mp.Pool(processes=FLAGS.num_procs) - results = pool.map(_preprocess_wrapper, range(len(self.train))) - for res in results: - file_names.extend(res[0]) - num_batch += res[1] - else: - for shard, path in enumerate(self.train): - data_shard = self.vocab.encode_file(path, ordered=False, - add_double_eos=True) - - num_shuffle = FLAGS.num_shuffle - - for shuffle in range(num_shuffle): - print("Processing shard {} shuffle {}".format(shard, shuffle)) - basename = "train-{:03d}-{:02d}".format(shard, shuffle) - np.random.shuffle(data_shard) - file_name, num_batch_ = create_ordered_tfrecords( - save_dir, basename, np.concatenate(data_shard), bsz, tgt_len, - num_core_per_host, - self.cutoffs, bin_sizes, use_tpu=use_tpu) - file_names.append(file_name) - num_batch += num_batch_ - - else: - file_name, num_batch = create_ordered_tfrecords( - save_dir, split, getattr(self, split), bsz, tgt_len, - num_core_per_host, - self.cutoffs, bin_sizes, use_tpu=use_tpu) - file_names.append(file_name) - - with open(record_info_path, "w") as fp: - record_info = { - "filenames": file_names, - "bin_sizes": bin_sizes, - "num_batch": num_batch - } - json.dump(record_info, fp) - - -def get_bin_sizes(data, batch_size, tgt_len, cutoffs, std_mult=[2.5, 2.5, 2.5]): - """ - Note: the `batch_size` here should be per-core batch size - """ - bin_sizes = [] - - def _nearest_to_eight(x): # so that it's faster on TPUs - y = x - x % 8 - return y + 8 if x % 8 >= 4 else max(8, y) - - if cutoffs: - num_batch = len(data) // batch_size // tgt_len - - data = data[:batch_size * num_batch * tgt_len] - data = data.reshape(batch_size, num_batch, tgt_len) - - tot = batch_size * tgt_len - for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])): - mask = (data >= left) * (data < right) - percents = mask.astype(np.float64).sum(2).sum(0) / tot - mean = np.mean(percents) - std = np.std(percents) - - bin_size = int(math.ceil(tgt_len * batch_size * (mean + std_mult[b] * std))) - bin_size = 
_nearest_to_eight(bin_size) - bin_sizes.append(bin_size) - - return bin_sizes - - -def _int64_feature(values): - return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) - -def _float_feature(values): - return tf.train.Feature(float_list=tf.train.FloatList(value=values)) - -def batchify(data, batch_size, num_passes): - """ - if use_tpu = True: num_passes > 1 - - Since TPU training requires entire [bsz x tgt_len] chunks, it can discard - as many as `bsz * tgt_len` tokens in training. When `bsz` and `tgt_len` are - both large, as in the case of TPU training for Transformer-XL, the problem - may lead to detectable performance drop. - - Here, we use multiple randomly shifted copies to deal with this problem. - """ - if num_passes > 1: - data_len = len(data) - double_data = np.concatenate([data, data]) - data_list = [] - for i in range(num_passes): - start = np.random.randint(0, data_len) - data_list.append(double_data[start:start+data_len]) - data = np.concatenate(data_list) - - num_step = len(data) // batch_size - data = data[:batch_size * num_step] - data = data.reshape(batch_size, num_step) - - return data - - -def create_ordered_tfrecords(save_dir, basename, data, batch_size, tgt_len, - num_core_per_host, cutoffs=[], bin_sizes=[], - num_passes=1, use_tpu=False): - - if use_tpu: - file_name = "{}.bsz-{}.tlen-{}.core-{}.tfrecords".format( - basename, batch_size, tgt_len, num_core_per_host) - else: - file_name = "{}.bsz-{}.tlen-{}.tfrecords".format( - basename, batch_size, tgt_len) - - save_path = os.path.join(save_dir, file_name) - record_writer = tf.python_io.TFRecordWriter(save_path) - - batched_data = batchify(data, batch_size, num_passes) - - num_batch = 0 - # for t in range(0, batched_data.shape[1] - tgt_len - 1, tgt_len): - for t in range(0, batched_data.shape[1] - 1, tgt_len): - cur_tgt_len = min(batched_data.shape[1] - 1 - t, tgt_len) - # drop the remainder if use tpu - if use_tpu and cur_tgt_len < tgt_len: - break - if num_batch % 500 == 0: - print(" processing batch {}".format(num_batch)) - for idx in range(batch_size): - inputs = batched_data[idx, t:t + cur_tgt_len] - labels = batched_data[idx, t + 1:t + cur_tgt_len + 1] - - # features dict - feature = { - "inputs": _int64_feature(inputs), - "labels": _int64_feature(labels), - } - - if len(cutoffs) > 0 and use_tpu: - # validate `bin_sizes` and `cutoffs` - assert len(cutoffs) - len(bin_sizes) == 2, \ - "len(cutoffs) - len(bin_sizes) != 2" - - # mask for bin 0 - left, right = cutoffs[:2] - inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32) - tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32) - - feature["inp_mask"] = _float_feature(inp_mask) - feature["tgt_mask"] = _float_feature(tgt_mask) - - # refresh `inp_cnts` and `tgt_cnts` for each TPU core - if idx % (batch_size // num_core_per_host) == 0: - inp_cnts = [0] * len(bin_sizes) - tgt_cnts = [0] * len(bin_sizes) - - head_labels = np.copy(labels) - inp_pos_per_bin, tgt_pos_per_bin = [], [] - for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])): - inp_pos = np.where((inputs >= left) * (inputs < right))[0] - tgt_pos = np.where((labels >= left) * (labels < right))[0] - inp_pos_per_bin.append(inp_pos) - tgt_pos_per_bin.append(tgt_pos) - - head_labels[tgt_pos] = cutoffs[1] + b - - feature["head_labels"] = _int64_feature(head_labels) - - # permutation feature - def _add_perm_feature(feature, pos_per_bin, cnts, prefix): - for b, pos in enumerate(pos_per_bin): - idx_tuple = [] - for p in pos: - if cnts[b] < bin_sizes[b]: - 
idx_tuple.append([p, cnts[b]]) - cnts[b] += 1 - else: - break - - n_tup = len(idx_tuple) - tup = np.array(idx_tuple).reshape(n_tup * 2) - - feature["{}_cnt_{}".format(prefix, b)] = _int64_feature([n_tup]) - feature["{}_tup_{}".format(prefix, b)] = _int64_feature(tup) - - _add_perm_feature(feature, inp_pos_per_bin, inp_cnts, "inp") - _add_perm_feature(feature, tgt_pos_per_bin, tgt_cnts, "tgt") - - example = tf.train.Example(features=tf.train.Features(feature=feature)) - record_writer.write(example.SerializeToString()) - - num_batch += 1 - - record_writer.close() - print("Done writing {}. batches: {}".format(file_name, num_batch)) - - return file_name, num_batch - - -def get_lm_corpus(data_dir, dataset): - fn = os.path.join(data_dir, "cache.pkl") - - if exists(fn): - print("Loading cached dataset...") - with open(fn, "rb") as fp: - corpus = pickle.load(fp) - else: - print("Producing dataset...") - kwargs = {} - if dataset in ["wt103", "wt2"]: - kwargs["special"] = [""] - kwargs["lower_case"] = False - elif dataset == "ptb": - kwargs["special"] = [""] - kwargs["lower_case"] = True - elif dataset == "lm1b": - kwargs["special"] = [] - kwargs["lower_case"] = False - kwargs["vocab_file"] = os.path.join(data_dir, "1b_word_vocab.txt") - elif dataset in ["enwik8", "text8"]: - pass - - corpus = Corpus(data_dir, dataset, **kwargs) - - print("Saving dataset...") - with open(fn, "wb") as fp: - pickle.dump(corpus, fp, protocol=2) - - corpus_info = { - "vocab_size" : len(corpus.vocab), - "cutoffs" : corpus.cutoffs, - "dataset" : corpus.dataset - } - with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp: - json.dump(corpus_info, fp) - - return corpus - - -def main(unused_argv): - del unused_argv # Unused - - corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset) - - save_dir = os.path.join(FLAGS.data_dir, "tfrecords") - if not exists(save_dir): - makedirs(save_dir) - - # test mode - if FLAGS.per_host_test_bsz > 0: - corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz, - FLAGS.tgt_len, FLAGS.num_core_per_host, - FLAGS=FLAGS) - return - - for split, batch_size in zip( - ["train", "valid"], - [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]): - - if batch_size <= 0: continue - print("Converting {} set...".format(split)) - corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len, - FLAGS.num_core_per_host, FLAGS=FLAGS) - - -def load_record_info(record_info_dir, split, per_host_bsz, tgt_len, - num_core_per_host, use_tpu): - if use_tpu: - record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format( - split, per_host_bsz, tgt_len, num_core_per_host) - else: - record_name = "record_info-{}.bsz-{}.tlen-{}.json".format( - split, per_host_bsz, tgt_len) - - record_info_path = os.path.join(record_info_dir, record_name) - with open(record_info_path, "r") as fp: - record_info = json.load(fp) - - return record_info - -def get_input_fn(record_info_dir, split, per_host_bsz, tgt_len, - num_core_per_host, num_hosts=1, use_tpu=False): - """Creates input function.""" - record_info = load_record_info(record_info_dir, split, per_host_bsz, tgt_len, - num_core_per_host, use_tpu=use_tpu) - - file_names = record_info["filenames"] - bin_sizes = record_info["bin_sizes"] - num_batch = record_info["num_batch"] - - tf.logging.info("[{}] File names {}".format(split, file_names)) - - def input_fn(params): - # per-core batch size - per_core_bsz = params["batch_size"] - - # data_dir could be a remote path, e.g., a google storage url - data_dir = params["data_dir"] - - def parser(record): - # 
preprocess "inp_perm" and "tgt_perm" - def _process_perm_feature(example, prefix): - for b in range(len(bin_sizes)): - cnt = example.pop("{}_cnt_{}".format(prefix, b))[0] - tup = example.pop("{}_tup_{}".format(prefix, b)) - - tup = tf.reshape( - tf.sparse_tensor_to_dense(tup), - shape=[cnt, 2]) - - # tf.float32 - perm = tf.sparse_to_dense( - sparse_indices=tup, - output_shape=[tgt_len, bin_sizes[b]], - sparse_values=1.0, - default_value=0.0) - - example["{}_perm_{}".format(prefix, b)] = perm - - # whether allow the last batch with a potentially shorter length - if use_tpu: - record_spec = { - "inputs": tf.FixedLenFeature([tgt_len], tf.int64), - "labels": tf.FixedLenFeature([tgt_len], tf.int64), - } - else: - record_spec = { - "inputs": tf.VarLenFeature(tf.int64), - "labels": tf.VarLenFeature(tf.int64), - } - - # permutation related features - if bin_sizes and use_tpu: - # tf.float32 - record_spec["inp_mask"] = tf.FixedLenFeature([tgt_len], tf.float32) - record_spec["tgt_mask"] = tf.FixedLenFeature([tgt_len], tf.float32) - - record_spec["head_labels"] = tf.FixedLenFeature([tgt_len], tf.int64) - - for b in range(len(bin_sizes)): - record_spec["inp_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64) - record_spec["inp_tup_{}".format(b)] = tf.VarLenFeature(tf.int64) - record_spec["tgt_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64) - record_spec["tgt_tup_{}".format(b)] = tf.VarLenFeature(tf.int64) - - # retrieve serialized example - example = tf.parse_single_example( - serialized=record, - features=record_spec) - - # transform permutation tuples to permutation matrices - if bin_sizes and use_tpu: - _process_perm_feature(example, "inp") - _process_perm_feature(example, "tgt") - - # cast int64 into int32 - # cast sparse to dense - for key in list(example.keys()): - val = example[key] - if tf.keras.backend.is_sparse(val): - val = tf.sparse.to_dense(val) - if val.dtype == tf.int64: - val = tf.to_int32(val) - example[key] = val - - if use_tpu: - return example - else: - return example["inputs"], example["labels"] - - file_paths = [] - for file_name in file_names: - file_path = os.path.join(data_dir, file_name) - file_paths.append(file_path) - - if split == "train": - dataset = tf.data.Dataset.from_tensor_slices(file_paths) - if len(file_paths) > 1: - dataset = dataset.shuffle(len(file_paths)).repeat() - dataset = tf.data.TFRecordDataset(dataset) - elif num_hosts > 1: - host_id = params["context"].current_host - # drop the remaining batches - num_batch_per_host = num_batch // num_hosts - - my_start_sample_id = (host_id * num_batch_per_host * num_core_per_host * - per_core_bsz) - my_sample_num = num_batch_per_host * num_core_per_host * per_core_bsz - dataset = tf.data.TFRecordDataset(dataset).skip( - my_start_sample_id).take(my_sample_num) - else: - dataset = tf.data.TFRecordDataset(dataset) - - dataset = dataset.map(parser).cache().repeat() - dataset = dataset.batch(per_core_bsz, drop_remainder=True) - dataset = dataset.prefetch(num_core_per_host * per_core_bsz) - else: - # do not shuffle, repeat or cache in evaluation - dataset = tf.data.Dataset.from_tensor_slices(file_paths) - dataset = tf.data.TFRecordDataset(dataset) - dataset = dataset.map(parser) - dataset = dataset.batch(per_core_bsz, drop_remainder=True) - - return dataset - - if split == "train" and num_hosts > 1: - record_info["num_batch"] = num_batch // num_hosts - - return input_fn, record_info - -def get_corpus_info(corpus_info_path): - with open(corpus_info_path, "r") as fp: - corpus_info = json.load(fp) - return corpus_info 
- -if __name__ == "__main__": - FLAGS = flags.FLAGS - flags.DEFINE_string("data_dir", None, - help="Location of the data corpus") - flags.DEFINE_enum("dataset", "wt103", - ["ptb", "wt2", "wt103", "lm1b", "enwik8", "text8"], - help="Dataset name.") - flags.DEFINE_integer("per_host_train_bsz", 60, - help="train batch size each host") - flags.DEFINE_integer("per_host_valid_bsz", 60, - help="valid batch size each host") - flags.DEFINE_integer("per_host_test_bsz", 0, - help="If > 0, enter test mode and process test set only." - "Otherwise, process train and dev sets only.") - flags.DEFINE_integer("tgt_len", 70, - help="number of tokens to predict") - flags.DEFINE_integer("max_batch", -1, - help="run in debug mode") - flags.DEFINE_integer("num_core_per_host", 8, - help="8 for TPU v2.") - flags.DEFINE_bool("debug", default=False, - help="Process only the first batch without shuffle for lm1b.") - flags.DEFINE_integer("num_procs", 1, - help="number of processes") - flags.DEFINE_integer("num_passes", 10, - help="number of passes when use_tpu=True") - flags.DEFINE_integer("num_shuffle", 4, - help="number of shuffles for lm1b") - flags.DEFINE_bool("use_tpu", True, - help="use tpu") - - tf.app.run(main) diff --git a/transformer-xl/tf/gpu_utils.py b/transformer-xl/tf/gpu_utils.py deleted file mode 100644 index ea4b1b7..0000000 --- a/transformer-xl/tf/gpu_utils.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import tensorflow as tf - -def assign_to_gpu(gpu=0, ps_dev="/DEVICE:CPU:0"): - def _assign(op): - node_def = op if isinstance(op, tf.NodeDef) else op.node_def - if node_def.op == "Variable": - return ps_dev - else: - return "/gpu:%d" % gpu - return _assign - - -def average_grads_and_vars(tower_grads_and_vars): - def average_dense(grad_and_vars): - if len(grad_and_vars) == 1: - return grad_and_vars[0][0] - - grad = grad_and_vars[0][0] - for g, _ in grad_and_vars[1:]: - grad += g - return grad / len(grad_and_vars) - - def average_sparse(grad_and_vars): - if len(grad_and_vars) == 1: - return grad_and_vars[0][0] - - indices = [] - values = [] - for g, _ in grad_and_vars: - indices += [g.indices] - values += [g.values] - indices = tf.concat(indices, 0) - values = tf.concat(values, 0) / len(grad_and_vars) - return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape) - - average_grads_and_vars = [] - for grad_and_vars in zip(*tower_grads_and_vars): - if grad_and_vars[0][0] is None: - grad = None - elif isinstance(grad_and_vars[0][0], tf.IndexedSlices): - grad = average_sparse(grad_and_vars) - else: - grad = average_dense(grad_and_vars) - # Keep in mind that the Variables are redundant because they are shared - # across towers. So .. we will just return the first tower's pointer to - # the Variable. - v = grad_and_vars[0][1] - grad_and_var = (grad, v) - average_grads_and_vars.append(grad_and_var) - return average_grads_and_vars - - -def load_from_checkpoint(saver, logdir): - sess = tf.get_default_session() - ckpt = tf.train.get_checkpoint_state(logdir) - if ckpt and ckpt.model_checkpoint_path: - if os.path.isabs(ckpt.model_checkpoint_path): - # Restores from checkpoint with absolute path. - saver.restore(sess, ckpt.model_checkpoint_path) - else: - # Restores from checkpoint with relative path. 
- saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path)) - return True - return False diff --git a/transformer-xl/tf/model.py b/transformer-xl/tf/model.py deleted file mode 100644 index bab7bee..0000000 --- a/transformer-xl/tf/model.py +++ /dev/null @@ -1,546 +0,0 @@ -import tensorflow as tf - - -def positional_embedding(pos_seq, inv_freq, bsz=None): - sinusoid_inp = tf.einsum('i,j->ij', pos_seq, inv_freq) - pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1) - if bsz is not None: - return tf.tile(pos_emb[:, None, :], [1, bsz, 1]) - else: - return pos_emb[:, None, :] - - -def positionwise_FF(inp, d_model, d_inner, dropout, kernel_initializer, - scope='ff', is_training=True): - output = inp - with tf.variable_scope(scope): - output = tf.layers.dense(inp, d_inner, activation=tf.nn.relu, - kernel_initializer=kernel_initializer, - name='layer_1') - output = tf.layers.dropout(output, dropout, training=is_training, - name='drop_1') - output = tf.layers.dense(output, d_model, - kernel_initializer=kernel_initializer, - name='layer_2') - output = tf.layers.dropout(output, dropout, training=is_training, - name='drop_2') - output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1) - return output - - -def rel_shift(x): - x_size = tf.shape(x) - - x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]]) - x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]]) - x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) - x = tf.reshape(x, x_size) - - return x - - -def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model, - n_head, d_head, dropout, dropatt, is_training, - kernel_initializer, scope='rel_attn'): - scale = 1 / (d_head ** 0.5) - with tf.variable_scope(scope): - qlen = tf.shape(w)[0] - rlen = tf.shape(r)[0] - bsz = tf.shape(w)[1] - - cat = tf.concat([mems, w], - 0) if mems is not None and mems.shape.ndims > 1 else w - w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False, - kernel_initializer=kernel_initializer, name='qkv') - r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False, - kernel_initializer=kernel_initializer, name='r') - - w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1) - w_head_q = w_head_q[-qlen:] - - klen = tf.shape(w_head_k)[0] - - w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head]) - w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head]) - w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head]) - - r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head]) - - rw_head_q = w_head_q + r_w_bias - rr_head_q = w_head_q + r_r_bias - - AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k) - BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k) - BD = rel_shift(BD) - - attn_score = (AC + BD) * scale - attn_mask_t = attn_mask[:, :, None, None] - attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t - - attn_prob = tf.nn.softmax(attn_score, 1) - attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training) - - attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v) - size_t = tf.shape(attn_vec) - attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head]) - - attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False, - kernel_initializer=kernel_initializer, name='o') - attn_out = tf.layers.dropout(attn_out, dropout, training=is_training) - - output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1) - return output - - -def embedding_lookup(lookup_table, x, use_tpu=True): - if use_tpu: - n_token = tf.shape(lookup_table)[0] - 
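The `rel_shift` function above implements the usual Transformer-XL relative-position trick: prepend a column of zeros, reinterpret the buffer with the two leading dimensions swapped, drop the first row, and reshape back, so that row `i` of the `BD` term ends up shifted left by `(qlen - 1 - i)` positions and aligned with the correct relative offsets without an explicit gather. A NumPy sketch of the same pad-reshape-slice on a single `[qlen x klen]` slice (the name `rel_shift_2d` and the toy matrix are assumptions; the TF version above operates on 4-D `[qlen, klen, bsz, n_head]` tensors):

```python
import numpy as np

def rel_shift_2d(x):
    """Pad-reshape-slice shift on a [qlen, klen] matrix (batch/head dims omitted)."""
    qlen, klen = x.shape
    x = np.pad(x, ((0, 0), (1, 0)))           # prepend a zero column -> [qlen, klen + 1]
    x = x.reshape(klen + 1, qlen)              # reinterpret the same buffer row-major
    return x[1:].reshape(qlen, klen)           # drop the first row and reshape back

# toy BD term: entry (i, j) labelled 10*i + j so the shift is easy to see
bd = np.arange(3)[:, None] * 10 + np.arange(4)[None, :]
print(bd)
print(rel_shift_2d(bd))
# row i is shifted left by (qlen - 1 - i); the entries that wrap in at the end
# (pad zeros / values spilling over from the next row) are removed by attn_mask
# in the model, so only the correctly aligned positions contribute.
```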
one_hot_idx = tf.one_hot(x, n_token) - if one_hot_idx.shape.ndims == 2: - return tf.einsum('nd,in->id', lookup_table, one_hot_idx) - else: - return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx) - else: - return tf.nn.embedding_lookup(lookup_table, x) - - -def mask_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer, - proj_initializer, div_val=1, - proj_same_dim=True, - scope='adaptive_embed', **kwargs): - emb_scale = d_proj ** 0.5 - with tf.variable_scope(scope): - if div_val == 1: - lookup_table = tf.get_variable('lookup_table', [n_token, d_embed], - initializer=initializer) - y = embedding_lookup(lookup_table, x, use_tpu=False) - if d_proj != d_embed: - proj_W = tf.get_variable('proj_W', [d_embed, d_proj], - initializer=proj_initializer) - y = tf.einsum('ibe,ed->ibd', y, proj_W) - else: - proj_W = None - ret_params = [lookup_table, proj_W] - else: - tables, projs = [], [] - cutoff_ends = [0] + cutoffs + [n_token] - x_size = tf.shape(x) - y = tf.zeros([x_size[0], x_size[1], d_proj]) - for i in range(len(cutoff_ends) - 1): - with tf.variable_scope('cutoff_{}'.format(i)): - l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] - mask = (x >= l_idx) & (x < r_idx) - cur_x = tf.boolean_mask(x, mask) - l_idx - cur_d_embed = d_embed // (div_val ** i) - lookup_table = tf.get_variable('lookup_table', - [r_idx - l_idx, cur_d_embed], - initializer=initializer) - cur_y = embedding_lookup(lookup_table, cur_x, use_tpu=False) - if d_proj == cur_d_embed and not proj_same_dim: - proj_W = None - else: - proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj], - initializer=proj_initializer) - cur_y = tf.einsum('id,de->ie', cur_y, proj_W) - mask_idx = tf.to_int64(tf.where(mask)) - y += tf.scatter_nd(mask_idx, cur_y, tf.to_int64(tf.shape(y))) - tables.append(lookup_table) - projs.append(proj_W) - ret_params = [tables, projs] - - y *= emb_scale - return y, ret_params - - -def mul_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer, - proj_initializer, div_val=1, perms=None, - proj_same_dim=True, - scope='adaptive_embed'): - """ - perms: If None, first compute W = W1 x W2 (projection for each bin), - and then compute X x W (embedding lookup). If not None, - use bin-based embedding lookup with max_bin_size defined by - the shape of perms. 
- """ - emb_scale = d_proj ** 0.5 - with tf.variable_scope(scope): - if div_val == 1: - lookup_table = tf.get_variable('lookup_table', [n_token, d_embed], - initializer=initializer) - y = embedding_lookup(lookup_table, x) - if d_proj != d_embed: - proj_W = tf.get_variable('proj_W', [d_embed, d_proj], - initializer=proj_initializer) - y = tf.einsum('ibe,ed->ibd', y, proj_W) - else: - proj_W = None - ret_params = [lookup_table, proj_W] - else: - tables, projs = [], [] - cutoff_ends = [0] + cutoffs + [n_token] - x_size = tf.shape(x) - if perms is None: - cat_lookup = [] - else: - cat_lookup = tf.zeros([x_size[0], x_size[1], d_proj]) - for i in range(len(cutoff_ends) - 1): - with tf.variable_scope('cutoff_{}'.format(i)): - l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] - cur_d_embed = d_embed // (div_val ** i) - lookup_table = tf.get_variable('lookup_table', - [r_idx - l_idx, cur_d_embed], - initializer=initializer) - if cur_d_embed == d_proj and not proj_same_dim: - proj_W = None - else: - proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj], - initializer=proj_initializer) - if perms is None: - cat_lookup.append(tf.einsum('ie,ed->id', lookup_table, proj_W)) - else: - # speed up the computation of the first bin - # also save some meory - if i == 0: - cur_y = embedding_lookup(lookup_table, tf.minimum(x, r_idx - 1)) - if proj_W is not None: - cur_y = tf.einsum('ibe,ed->ibd', cur_y, proj_W) - cur_y *= perms[i][:, :, None] - cat_lookup += cur_y - else: - cur_x = tf.einsum('ib,ibk->k', tf.to_float(x - l_idx), perms[i]) - cur_x = tf.to_int32(cur_x) - cur_y = embedding_lookup(lookup_table, cur_x) - if proj_W is not None: - cur_y = tf.einsum('ke,ed->kd', cur_y, proj_W) - cat_lookup += tf.einsum('kd,ibk->ibd', cur_y, perms[i]) - tables.append(lookup_table) - projs.append(proj_W) - if perms is None: - cat_lookup = tf.concat(cat_lookup, 0) - y = embedding_lookup(cat_lookup, x) - else: - y = cat_lookup - ret_params = [tables, projs] - - y *= emb_scale - return y, ret_params - - -def mask_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs, - params, tie_projs, - initializer=None, proj_initializer=None, - div_val=1, scope='adaptive_softmax', - proj_same_dim=True, - return_mean=True, **kwargs): - def _logit(x, W, b, proj): - y = x - if proj is not None: - y = tf.einsum('ibd,ed->ibe', y, proj) - return tf.einsum('ibd,nd->ibn', y, W) + b - - params_W, params_projs = params[0], params[1] - - def _gather_logprob(logprob, target): - lp_size = tf.shape(logprob) - r = tf.range(lp_size[0]) - idx = tf.stack([r, target], 1) - return tf.gather_nd(logprob, idx) - - with tf.variable_scope(scope): - if len(cutoffs) == 0: - softmax_b = tf.get_variable('bias', [n_token], - initializer=tf.zeros_initializer()) - output = _logit(hidden, params_W, softmax_b, params_projs) - nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, - logits=output) - else: - cutoff_ends = [0] + cutoffs + [n_token] - nll = tf.zeros_like(target, dtype=tf.float32) - for i in range(len(cutoff_ends) - 1): - with tf.variable_scope('cutoff_{}'.format(i)): - l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] - mask = (target >= l_idx) & (target < r_idx) - mask_idx = tf.where(mask) - cur_target = tf.boolean_mask(target, mask) - l_idx - cur_d_embed = d_embed // (div_val ** i) - - if div_val == 1: - cur_W = params_W[l_idx: r_idx] - else: - cur_W = params_W[i] - cur_b = tf.get_variable('b', [r_idx - l_idx], - initializer=tf.zeros_initializer()) - if tie_projs[i]: - if div_val == 1: - cur_proj = params_projs - else: - 
cur_proj = params_projs[i] - else: - if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed: - cur_proj = None - else: - cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj], - initializer=proj_initializer) - if i == 0: - cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed], - initializer=tf.zeros_initializer()) - cluster_b = tf.get_variable('cluster_b', [len(cutoffs)], - initializer=tf.zeros_initializer()) - cur_W = tf.concat([cur_W, cluster_W], 0) - cur_b = tf.concat([cur_b, cluster_b], 0) - - head_logit = _logit(hidden, cur_W, cur_b, cur_proj) - head_logprob = tf.nn.log_softmax(head_logit) - cur_head_logprob = tf.boolean_mask(head_logprob, mask) - cur_logprob = _gather_logprob(cur_head_logprob, cur_target) - else: - cur_head_logprob = tf.boolean_mask(head_logprob, mask) - cur_hidden = tf.boolean_mask(hidden, mask) - tail_logit = tf.squeeze(_logit( - cur_hidden[None], cur_W, cur_b, cur_proj), 0) - tail_logprob = tf.nn.log_softmax(tail_logit) - cur_logprob = (cur_head_logprob[:, cutoff_ends[1] + i - 1] + - _gather_logprob(tail_logprob, cur_target)) - nll += tf.scatter_nd(mask_idx, -cur_logprob, - tf.to_int64(tf.shape(nll))) - if return_mean: - nll = tf.reduce_mean(nll) - return nll - - -def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs, - params, tie_projs, - initializer=None, proj_initializer=None, - div_val=1, perms=None, proj_same_dim=True, - scope='adaptive_softmax', - **kwargs): - def _logit(x, W, b, proj): - y = x - if x.shape.ndims == 3: - if proj is not None: - y = tf.einsum('ibd,ed->ibe', y, proj) - return tf.einsum('ibd,nd->ibn', y, W) + b - else: - if proj is not None: - y = tf.einsum('id,ed->ie', y, proj) - return tf.einsum('id,nd->in', y, W) + b - - params_W, params_projs = params[0], params[1] - - with tf.variable_scope(scope): - if len(cutoffs) == 0: - softmax_b = tf.get_variable('bias', [n_token], - initializer=tf.zeros_initializer()) - output = _logit(hidden, params_W, softmax_b, params_projs) - nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target, - logits=output) - nll = tf.reduce_mean(nll) - else: - total_loss, total_cnt = 0, 0 - cutoff_ends = [0] + cutoffs + [n_token] - for i in range(len(cutoff_ends) - 1): - with tf.variable_scope('cutoff_{}'.format(i)): - l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1] - - cur_d_embed = d_embed // (div_val ** i) - - if div_val == 1: - cur_W = params_W[l_idx: r_idx] - else: - cur_W = params_W[i] - cur_b = tf.get_variable('b', [r_idx - l_idx], - initializer=tf.zeros_initializer()) - if tie_projs[i]: - if div_val == 1: - cur_proj = params_projs - else: - cur_proj = params_projs[i] - else: - if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed: - cur_proj = None - else: - cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj], - initializer=proj_initializer) - - if i == 0: - cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed], - initializer=tf.zeros_initializer()) - cluster_b = tf.get_variable('cluster_b', [len(cutoffs)], - initializer=tf.zeros_initializer()) - cur_W = tf.concat([cur_W, cluster_W], 0) - cur_b = tf.concat([cur_b, cluster_b], 0) - - head_logit = _logit(hidden, cur_W, cur_b, cur_proj) - - head_target = kwargs.get("head_target") - head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=head_target, - logits=head_logit) - - masked_loss = head_nll * perms[i] - total_loss += tf.reduce_sum(masked_loss) - total_cnt += tf.reduce_sum(perms[i]) - - # head_logprob = tf.nn.log_softmax(head_logit) - - # final_logprob = 
head_logprob * perms[i][:, :, None] - # final_target = tf.one_hot(target, tf.shape(head_logprob)[2]) - # total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target) - # total_cnt += tf.reduce_sum(perms[i]) - else: - cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i]) - - cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i]) - tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj) - - tail_target = tf.einsum('ib,ibk->k', tf.to_float(target - l_idx), - perms[i]) - tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits( - labels=tf.to_int32(tail_target), - logits=tail_logit) - - sum_nll = cur_head_nll + tail_nll - mask = tf.reduce_sum(perms[i], [0, 1]) - - masked_loss = sum_nll * mask - total_loss += tf.reduce_sum(masked_loss) - total_cnt += tf.reduce_sum(mask) - - nll = total_loss / total_cnt - - return nll - - -def _create_mask(qlen, mlen, same_length=False): - attn_mask = tf.ones([qlen, qlen]) - mask_u = tf.matrix_band_part(attn_mask, 0, -1) - mask_dia = tf.matrix_band_part(attn_mask, 0, 0) - attn_mask_pad = tf.zeros([qlen, mlen]) - ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1) - if same_length: - mask_l = tf.matrix_band_part(attn_mask, -1, 0) - ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1) - return ret - -def _cache_mem(curr_out, prev_mem, mem_len=None): - if mem_len is None or prev_mem is None: - new_mem = curr_out - elif mem_len == 0: - return prev_mem - else: - new_mem = tf.concat([prev_mem, curr_out], 0)[- mem_len:] - - return tf.stop_gradient(new_mem) - - -def transformer(dec_inp, target, mems, n_token, n_layer, d_model, d_embed, - n_head, d_head, d_inner, dropout, dropatt, - initializer, is_training, proj_initializer=None, - mem_len=None, cutoffs=[], div_val=1, tie_projs=[], - same_length=False, clamp_len=-1, use_tpu=True, - input_perms=None, target_perms=None, head_target=None, - untie_r=False, proj_same_dim=True, - scope='transformer'): - """ - cutoffs: a list of python int. Cutoffs for adaptive softmax. - tie_projs: a list of python bools. Whether to tie the projections. - use_tpu: if True, use one_hot in embedding lookup and bin-based implementation - of adaptive softmax. - perms: a list of tensors. Each tensor should of size [len, bsz, bin_size]. - Only used in the adaptive setting. 
- """ - new_mems = [] - with tf.variable_scope(scope): - if untie_r: - r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head], - initializer=initializer) - r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head], - initializer=initializer) - else: - r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head], - initializer=initializer) - r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head], - initializer=initializer) - - qlen = tf.shape(dec_inp)[0] - mlen = tf.shape(mems[0])[0] if mems is not None else 0 - klen = mlen + qlen - - if proj_initializer is None: - proj_initializer = initializer - lookup_fn = (mul_adaptive_embedding_lookup if use_tpu else - mask_adaptive_embedding_lookup) - embeddings, shared_params = lookup_fn( - x=dec_inp, - n_token=n_token, - d_embed=d_embed, - d_proj=d_model, - cutoffs=cutoffs, - initializer=initializer, - proj_initializer=proj_initializer, - div_val= div_val, - perms=input_perms, - proj_same_dim=proj_same_dim) - - attn_mask = _create_mask(qlen, mlen, same_length) - - pos_seq = tf.range(klen - 1, -1, -1.0) - if clamp_len > 0: - pos_seq = tf.minimum(pos_seq, clamp_len) - inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model)) - pos_emb = positional_embedding(pos_seq, inv_freq) - - output = tf.layers.dropout(embeddings, dropout, training=is_training) - pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training) - - if mems is None: - mems = [None] * n_layer - - for i in range(n_layer): - # cache new mems - new_mems.append(_cache_mem(output, mems[i], mem_len)) - - with tf.variable_scope('layer_{}'.format(i)): - output = rel_multihead_attn( - w=output, - r=pos_emb, - r_w_bias=r_w_bias if not untie_r else r_w_bias[i], - r_r_bias=r_r_bias if not untie_r else r_r_bias[i], - attn_mask=attn_mask, - mems=mems[i], - d_model=d_model, - n_head=n_head, - d_head=d_head, - dropout=dropout, - dropatt=dropatt, - is_training=is_training, - kernel_initializer=initializer) - output = positionwise_FF( - inp=output, - d_model=d_model, - d_inner=d_inner, - dropout=dropout, - kernel_initializer=initializer, - is_training=is_training) - - output = tf.layers.dropout(output, dropout, training=is_training) - - logsoftmax_fn = (mul_adaptive_logsoftmax if use_tpu else - mask_adaptive_logsoftmax) - loss = logsoftmax_fn( - hidden=output, - target=target, - n_token=n_token, - d_embed=d_embed, - d_proj=d_model, - cutoffs=cutoffs, - params=shared_params, - tie_projs=tie_projs, - initializer=initializer, - proj_initializer=proj_initializer, - div_val=div_val, - perms=target_perms, - head_target=head_target, - proj_same_dim=proj_same_dim) - return loss, new_mems - diff --git a/transformer-xl/tf/scripts/enwik8_base_gpu.sh b/transformer-xl/tf/scripts/enwik8_base_gpu.sh deleted file mode 100644 index 6de09a0..0000000 --- a/transformer-xl/tf/scripts/enwik8_base_gpu.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=../data/enwik8/ - -# Model -N_LAYER=12 -D_MODEL=512 -D_EMBED=512 -N_HEAD=8 -D_HEAD=64 -D_INNER=2048 - -# Training -TGT_LEN=512 -MEM_LEN=512 - -BSZ=24 -NUM_CORE=4 - -# Testing -TEST_TGT_LEN=80 -TEST_MEM_LEN=2100 -TEST_CLAMP_LEN=820 - -TEST_BSZ=10 -TEST_NUM_CORE=1 - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=enwik8 \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${BSZ} \ - --per_host_valid_bsz=${BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=enwik8 \ - 
--tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'train' ]]; then - echo 'Run training...' - python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-enwik8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.1 \ - --dropatt=0.0 \ - --learning_rate=0.00025 \ - --warmup_steps=0 \ - --train_steps=400000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${BSZ} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=200 \ - --save_steps=4000 \ - --do_train=True \ - --do_eval=False \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-enwik8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test \ - ${@:2} -else - echo 'unknown argment 1' -fi \ No newline at end of file diff --git a/transformer-xl/tf/scripts/enwik8_large_tpu.sh b/transformer-xl/tf/scripts/enwik8_large_tpu.sh deleted file mode 100644 index e862fd7..0000000 --- a/transformer-xl/tf/scripts/enwik8_large_tpu.sh +++ /dev/null @@ -1,122 +0,0 @@ -#!/bin/bash - -# Path -LOCAL_DIR=../data/enwik8/ -GSDATA= -GSEXP= - -# TPU setting -NUM_HOST=2 -NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 - -TEST_NUM_HOST=1 -TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 - -# Model -N_LAYER=24 -D_MODEL=1024 -D_EMBED=1024 -N_HEAD=8 -D_HEAD=128 -D_INNER=3072 - -# Training -TGT_LEN=768 -MEM_LEN=768 -TRAIN_BSZ=64 -VALID_BSZ=64 - -# Testing -TEST_TGT_LEN=128 -TEST_MEM_LEN=3800 -TEST_CLAMP_LEN=1000 -TEST_BSZ=16 - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=enwik8 \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${TRAIN_BSZ} \ - --per_host_valid_bsz=${VALID_BSZ} \ - --num_core_per_host=${NUM_CORE} \ - --num_passes=10 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/ - - SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/ - -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=enwik8 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --num_passes=1 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/ - -elif [[ $1 == 'train' ]]; then - echo 'Run training...' 
- python train.py \ - --data_dir=${GSDATA}/enwik8-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/enwik8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.15 \ - --dropatt=0.15 \ - --learning_rate=0.00025 \ - --warmup_steps=4000 \ - --train_steps=400000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${TRAIN_BSZ} \ - --use_tpu=True \ - --num_host=${NUM_HOST} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=1000 \ - --save_steps=10000 \ - --do_train=True \ - --do_eval=False \ - ${@:2} - -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' - python train.py \ - --data_dir=${GSDATA}/enwik8-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/enwik8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --eval_batch_size=${TEST_BSZ} \ - --num_host=${TEST_NUM_HOST} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --use_tpu=True \ - --do_train=False \ - --do_eval_only=True \ - --eval_split=test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/tf/scripts/lm1b_base_gpu.sh b/transformer-xl/tf/scripts/lm1b_base_gpu.sh deleted file mode 100644 index 2dcb252..0000000 --- a/transformer-xl/tf/scripts/lm1b_base_gpu.sh +++ /dev/null @@ -1,110 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=../data/one-billion-words/ - -# Model -DIV_VAL=4 -N_LAYER=18 -D_MODEL=1024 -D_EMBED=1024 -N_HEAD=8 -D_HEAD=128 -D_INNER=4096 - -# Training -TGT_LEN=256 -MEM_LEN=256 - -BSZ=256 -NUM_CORE=4 - -# Testing -TEST_TGT_LEN=32 -TEST_MEM_LEN=128 -TEST_CLAMP_LEN=-1 - -TEST_BSZ=16 -TEST_NUM_CORE=1 - - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=lm1b \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${BSZ} \ - --per_host_valid_bsz=${BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=lm1b \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'train' ]]; then - echo 'Run training...' - python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-lm1b \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=False \ - --proj_same_dim=False \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.1 \ - --dropatt=0.0 \ - --learning_rate=0.00025 \ - --warmup_steps=0 \ - --train_steps=400000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${BSZ} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=200 \ - --save_steps=4000 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
- python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-lm1b \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=False \ - --proj_same_dim=False \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/tf/scripts/lm1b_large_tpu.sh b/transformer-xl/tf/scripts/lm1b_large_tpu.sh deleted file mode 100644 index 076478e..0000000 --- a/transformer-xl/tf/scripts/lm1b_large_tpu.sh +++ /dev/null @@ -1,136 +0,0 @@ -#!/bin/bash - -# Path -LOCAL_DIR=../data/one-billion-words/ -GSDATA= -GSEXP= - -# TPU setting -NUM_HOST=32 -NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 - -TEST_NUM_HOST=1 -TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 - -# Model -DIV_VAL=4 -N_LAYER=24 -D_MODEL=1280 -D_EMBED=1280 -N_HEAD=16 -D_HEAD=80 -D_INNER=8192 - -# Training -TGT_LEN=32 -MEM_LEN=32 -TRAIN_BSZ=512 -VALID_BSZ=512 -TRAIN_BSZ_PER_HOST=$((TRAIN_BSZ / NUM_HOST)) -VALID_BSZ_PER_HOST=$((VALID_BSZ / NUM_HOST)) - -# Testing -TEST_TGT_LEN=32 -TEST_MEM_LEN=128 -TEST_CLAMP_LEN=-1 -TEST_BSZ=8 - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=lm1b \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${TRAIN_BSZ_PER_HOST} \ - --per_host_valid_bsz=${VALID_BSZ_PER_HOST} \ - --num_core_per_host=${NUM_CORE} \ - --num_passes=10 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/ - - SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/ - -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=lm1b \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --num_passes=1 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/ - -elif [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --data_dir=${GSDATA}/lm1b-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/lm1b \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=False \ - --proj_same_dim=False \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.05 \ - --dropatt=0.05 \ - --init_std=0.005 \ - --learning_rate=0.0001 \ - --warmup_steps=30000 \ - --train_steps=1200000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${TRAIN_BSZ} \ - --num_hosts=${NUM_HOST} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=1000 \ - --save_steps=10000 \ - --use_tpu=True \ - --do_eval=False \ - ${@:2} - -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
- python train.py \ - --data_dir=${GSDATA}/lm1b-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/lm1b \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=False \ - --proj_same_dim=False \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_host=${TEST_NUM_HOST} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --use_tpu=True \ - --do_train=False \ - --do_eval_only=True \ - --eval_split=test \ - ${@:2} - -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/tf/scripts/text8_base_gpu.sh b/transformer-xl/tf/scripts/text8_base_gpu.sh deleted file mode 100644 index 1cff08a..0000000 --- a/transformer-xl/tf/scripts/text8_base_gpu.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=../data/text8/ - -# Model -N_LAYER=12 -D_MODEL=512 -D_EMBED=512 -N_HEAD=8 -D_HEAD=64 -D_INNER=2048 - -# Training -TGT_LEN=512 -MEM_LEN=512 - -BSZ=24 -NUM_CORE=4 - -# Testing -TEST_TGT_LEN=80 -TEST_MEM_LEN=2100 -TEST_CLAMP_LEN=820 - -TEST_BSZ=10 -TEST_NUM_CORE=1 - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=text8 \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${BSZ} \ - --per_host_valid_bsz=${BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=text8 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'train' ]]; then - echo 'Run training...' - python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-text8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.1 \ - --dropatt=0.0 \ - --learning_rate=0.00025 \ - --warmup_steps=0 \ - --train_steps=400000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${BSZ} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=200 \ - --save_steps=4000 \ - --do_train=True \ - --do_eval=False \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
- python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-text8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test \ - ${@:2} -else - echo 'unknown argment 1' -fi \ No newline at end of file diff --git a/transformer-xl/tf/scripts/text8_large_tpu.sh b/transformer-xl/tf/scripts/text8_large_tpu.sh deleted file mode 100644 index afcbbf5..0000000 --- a/transformer-xl/tf/scripts/text8_large_tpu.sh +++ /dev/null @@ -1,122 +0,0 @@ -#!/bin/bash - -# Path -LOCAL_DIR=../data/text8/ -GSDATA= -GSEXP= - -# TPU setting -NUM_HOST=2 -NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 - -TEST_NUM_HOST=1 -TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 - -# Model -N_LAYER=24 -D_MODEL=1024 -D_EMBED=1024 -N_HEAD=8 -D_HEAD=128 -D_INNER=3072 - -# Training -TGT_LEN=768 -MEM_LEN=768 -TRAIN_BSZ=64 -VALID_BSZ=64 - -# Testing -TEST_TGT_LEN=128 -TEST_MEM_LEN=3800 -TEST_CLAMP_LEN=1000 -TEST_BSZ=16 - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=text8 \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${TRAIN_BSZ} \ - --per_host_valid_bsz=${VALID_BSZ} \ - --num_core_per_host=${NUM_CORE} \ - --num_passes=10 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/ - - SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/ - -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=text8 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --num_passes=1 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/ - -elif [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --data_dir=${GSDATA}/text8-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/text8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.15 \ - --dropatt=0.15 \ - --learning_rate=0.00025 \ - --warmup_steps=4000 \ - --train_steps=400000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${TRAIN_BSZ} \ - --use_tpu=True \ - --num_host=${NUM_HOST} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=1000 \ - --save_steps=10000 \ - --do_train=True \ - --do_eval=False \ - ${@:2} - -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
- python train.py \ - --data_dir=${GSDATA}/text8-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/text8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --eval_batch_size=${TEST_BSZ} \ - --num_host=${TEST_NUM_HOST} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --use_tpu=True \ - --do_train=False \ - --do_eval_only=True \ - --eval_split=test \ - ${@:2} -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/tf/scripts/wt103_base_gpu.sh b/transformer-xl/tf/scripts/wt103_base_gpu.sh deleted file mode 100644 index c3bc810..0000000 --- a/transformer-xl/tf/scripts/wt103_base_gpu.sh +++ /dev/null @@ -1,108 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=../data/wikitext-103/ - -# Model -DIV_VAL=1 -N_LAYER=16 -D_MODEL=410 -D_EMBED=410 -N_HEAD=10 -D_HEAD=41 -D_INNER=2100 - -# Training -TGT_LEN=150 -MEM_LEN=150 - -BSZ=60 -NUM_CORE=4 - -# Testing -TEST_TGT_LEN=64 -TEST_MEM_LEN=640 -TEST_CLAMP_LEN=400 - -TEST_BSZ=10 -TEST_NUM_CORE=1 - - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=wt103 \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${BSZ} \ - --per_host_valid_bsz=${BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${DATA_ROOT}/ \ - --dataset=enwik8 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False \ - ${@:2} -elif [[ $1 == 'train' ]]; then - echo 'Run training...' - python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-wt103 \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=True \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.1 \ - --dropatt=0.0 \ - --learning_rate=0.00025 \ - --warmup_steps=0 \ - --train_steps=400000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${BSZ} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=200 \ - --save_steps=4000 \ - ${@:2} -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
- python train_gpu.py \ - --data_dir=${DATA_ROOT}/tfrecords \ - --record_info_dir=${DATA_ROOT}/tfrecords/ \ - --corpus_info_path=${DATA_ROOT}/corpus-info.json \ - --model_dir=EXP-wt103 \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=True \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test \ - ${@:2} -else - echo 'unknown argment 1' -fi \ No newline at end of file diff --git a/transformer-xl/tf/scripts/wt103_large_tpu.sh b/transformer-xl/tf/scripts/wt103_large_tpu.sh deleted file mode 100644 index c32fbcd..0000000 --- a/transformer-xl/tf/scripts/wt103_large_tpu.sh +++ /dev/null @@ -1,134 +0,0 @@ -#!/bin/bash - -# Path -LOCAL_DIR=../data/wikitext-103/ -GSDATA= -GSEXP= - -# TPU setting -NUM_HOST=4 -NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16 - -TEST_NUM_HOST=1 -TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16 - -# Model -DIV_VAL=4 -N_LAYER=18 -D_MODEL=1024 -D_EMBED=1024 -N_HEAD=16 -D_HEAD=64 -D_INNER=4096 - -# Training -TGT_LEN=384 -MEM_LEN=384 -TRAIN_BSZ=128 -VALID_BSZ=128 - -# Testing -TEST_TGT_LEN=128 -TEST_MEM_LEN=1600 -TEST_CLAMP_LEN=1000 -TEST_BSZ=8 - -if [[ $1 == 'train_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=wt103 \ - --tgt_len=${TGT_LEN} \ - --per_host_train_bsz=${TRAIN_BSZ} \ - --per_host_valid_bsz=${VALID_BSZ} \ - --num_core_per_host=${NUM_CORE} \ - --num_passes=10 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/ - - SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/ - -elif [[ $1 == 'test_data' ]]; then - python data_utils.py \ - --data_dir=${LOCAL_DIR}/ \ - --dataset=wt103 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --num_passes=1 \ - --use_tpu=True \ - ${@:2} - - SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}* - gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/ - -elif [[ $1 == 'train' ]]; then - echo 'Run training...' - python train.py \ - --data_dir=${GSDATA}/wt103-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/wt103 \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=True \ - --proj_same_dim=True \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.2 \ - --dropatt=0.2 \ - --init_std=0.005 \ - --learning_rate=0.00025 \ - --warmup_steps=16000 \ - --train_steps=4000000 \ - --tgt_len=${TGT_LEN} \ - --mem_len=${MEM_LEN} \ - --train_batch_size=${TRAIN_BSZ} \ - --num_hosts=${NUM_HOST} \ - --num_core_per_host=${NUM_CORE} \ - --iterations=1000 \ - --save_steps=10000 \ - --use_tpu=True \ - --do_eval=False \ - ${@:2} - -elif [[ $1 == 'eval' ]]; then - echo 'Run evaluation...' 
- python train.py \ - --data_dir=${GSDATA}/wt103-tfrecords \ - --record_info_dir=${LOCAL_DIR}/tfrecords/ \ - --corpus_info_path=${LOCAL_DIR}/corpus-info.json \ - --model_dir=${GSEXP}/wt103 \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=True \ - --proj_same_dim=True \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_host=${TEST_NUM_HOST} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --use_tpu=True \ - --do_train=False \ - --do_eval_only=True \ - --eval_split=test \ - ${@:2} - -else - echo 'unknown argment 1' -fi diff --git a/transformer-xl/tf/sota/download.sh b/transformer-xl/tf/sota/download.sh deleted file mode 100644 index 9a8db16..0000000 --- a/transformer-xl/tf/sota/download.sh +++ /dev/null @@ -1,87 +0,0 @@ -#!/bin/bash - -URL=http://curtis.ml.cmu.edu/datasets/pretrained_xl - -DATA_ROOT=./ - -function download () { - fileurl=${1} - filename=${fileurl##*/} - if [ ! -f ${filename} ]; then - echo ">>> Download '${filename}' from '${fileurl}'." - wget --quiet ${fileurl} - else - echo "*** File '${filename}' exists. Skip." - fi -} - -cd $DATA_ROOT -mkdir -p pretrained_xl && cd pretrained_xl - -# enwik8 -mkdir -p tf_enwik8 && cd tf_enwik8 - -mkdir -p data && cd data -download ${URL}/tf_enwiki8/data/cache.pkl -download ${URL}/tf_enwiki8/data/corpus-info.json -cd .. - -mkdir -p model && cd model -download ${URL}/tf_enwiki8/model/checkpoint -download ${URL}/tf_enwiki8/model/model.ckpt-0.data-00000-of-00001 -download ${URL}/tf_enwiki8/model/model.ckpt-0.index -download ${URL}/tf_enwiki8/model/model.ckpt-0.meta -cd .. - -cd .. - -# text8 -mkdir -p tf_text8 && cd tf_text8 - -mkdir -p data && cd data -download ${URL}/tf_text8/data/cache.pkl -download ${URL}/tf_text8/data/corpus-info.json -cd .. - -mkdir -p model && cd model -download ${URL}/tf_text8/model/checkpoint -download ${URL}/tf_text8/model/model.ckpt-0.data-00000-of-00001 -download ${URL}/tf_text8/model/model.ckpt-0.index -download ${URL}/tf_text8/model/model.ckpt-0.meta -cd .. - -cd .. - -# wt103 -mkdir -p tf_wt103 && cd tf_wt103 - -mkdir -p data && cd data -download ${URL}/tf_wt103/data/cache.pkl -download ${URL}/tf_wt103/data/corpus-info.json -cd .. - -mkdir -p model && cd model -download ${URL}/tf_wt103/model/checkpoint -download ${URL}/tf_wt103/model/model.ckpt-0.data-00000-of-00001 -download ${URL}/tf_wt103/model/model.ckpt-0.index -download ${URL}/tf_wt103/model/model.ckpt-0.meta -cd .. - -cd .. - -# lm1b -mkdir -p tf_lm1b && cd tf_lm1b - -mkdir -p data && cd data -download ${URL}/tf_lm1b/data/cache.pkl -download ${URL}/tf_lm1b/data/corpus-info.json -cd .. - -mkdir -p model && cd model -download ${URL}/tf_lm1b/model/checkpoint -download ${URL}/tf_lm1b/model/model.ckpt-1191000.data-00000-of-00001 -download ${URL}/tf_lm1b/model/model.ckpt-1191000.index -download ${URL}/tf_lm1b/model/model.ckpt-1191000.meta -cd .. - -cd .. 
diff --git a/transformer-xl/tf/sota/enwik8.sh b/transformer-xl/tf/sota/enwik8.sh deleted file mode 100644 index 27b45f0..0000000 --- a/transformer-xl/tf/sota/enwik8.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=./ -DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/data -MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/model - -# Model -N_LAYER=24 -D_MODEL=1024 -D_EMBED=1024 -N_HEAD=8 -D_HEAD=128 -D_INNER=3072 - -# Testing -TEST_TGT_LEN=128 -TEST_MEM_LEN=3800 -TEST_CLAMP_LEN=1000 - -TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0 -TEST_BSZ=16 -TEST_NUM_CORE=2 - - -echo 'Preprocess test set...' -python data_utils.py \ - --data_dir=${DATA_DIR}/ \ - --dataset=enwik8 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False - -echo 'Run evaluation on test set...' -python train_gpu.py \ - --data_dir=${DATA_DIR}/tfrecords \ - --record_info_dir=${DATA_DIR}/tfrecords/ \ - --corpus_info_path=${DATA_DIR}/corpus-info.json \ - --eval_ckpt_path=${TEST_CKPT_PATH} \ - --model_dir=EXP-enwik8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test diff --git a/transformer-xl/tf/sota/lm1b.sh b/transformer-xl/tf/sota/lm1b.sh deleted file mode 100644 index bd49918..0000000 --- a/transformer-xl/tf/sota/lm1b.sh +++ /dev/null @@ -1,63 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=./ -DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/data -MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/model - -# Model -DIV_VAL=4 -N_LAYER=24 -D_MODEL=1280 -D_EMBED=1280 -N_HEAD=16 -D_HEAD=80 -D_INNER=8192 - -# Testing -TEST_TGT_LEN=32 -TEST_MEM_LEN=128 -TEST_CLAMP_LEN=-1 - -TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-1191000 -TEST_BSZ=16 -TEST_NUM_CORE=1 - - -echo 'Preprocess test set...' -python data_utils.py \ - --data_dir=${DATA_DIR}/ \ - --dataset=lm1b \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False - -echo 'Run evaluation on test set...' 
-python train_gpu.py \ - --data_dir=${DATA_DIR}/tfrecords \ - --record_info_dir=${DATA_DIR}/tfrecords/ \ - --corpus_info_path=${DATA_DIR}/corpus-info.json \ - --eval_ckpt_path=${TEST_CKPT_PATH} \ - --model_dir=EXP-lm1b \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=False \ - --proj_same_dim=False \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test diff --git a/transformer-xl/tf/sota/text8.sh b/transformer-xl/tf/sota/text8.sh deleted file mode 100644 index 5d9d8f5..0000000 --- a/transformer-xl/tf/sota/text8.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=./ -DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/data -MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/model - -# Model -N_LAYER=24 -D_MODEL=1024 -D_EMBED=1024 -N_HEAD=8 -D_HEAD=128 -D_INNER=3072 - -# Testing -TEST_TGT_LEN=128 -TEST_MEM_LEN=3800 -TEST_CLAMP_LEN=1000 - -TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0 -TEST_BSZ=16 -TEST_NUM_CORE=2 - - -echo 'Preprocess test set...' -python data_utils.py \ - --data_dir=${DATA_DIR}/ \ - --dataset=text8 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False - -echo 'Run evaluation on test set...' -python train_gpu.py \ - --data_dir=${DATA_DIR}/tfrecords \ - --record_info_dir=${DATA_DIR}/tfrecords/ \ - --corpus_info_path=${DATA_DIR}/corpus-info.json \ - --eval_ckpt_path=${TEST_CKPT_PATH} \ - --model_dir=EXP-text8 \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test diff --git a/transformer-xl/tf/sota/wt103.sh b/transformer-xl/tf/sota/wt103.sh deleted file mode 100644 index 4b7f626..0000000 --- a/transformer-xl/tf/sota/wt103.sh +++ /dev/null @@ -1,71 +0,0 @@ -#!/bin/bash - -# Data -DATA_ROOT=./ -DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/data -MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/model - -# Model -DIV_VAL=4 -N_LAYER=18 -D_MODEL=1024 -D_EMBED=1024 -N_HEAD=16 -D_HEAD=64 -D_INNER=4096 - -# Training -TGT_LEN=256 -MEM_LEN=256 - -BSZ=16 -NUM_CORE=2 - -# Testing -TEST_TGT_LEN=128 -TEST_MEM_LEN=1600 -TEST_CLAMP_LEN=1000 - -TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0 -TEST_BSZ=16 -TEST_NUM_CORE=1 - - -echo 'Preprocess test set...' -python data_utils.py \ - --data_dir=${DATA_DIR}/ \ - --dataset=enwik8 \ - --tgt_len=${TEST_TGT_LEN} \ - --per_host_test_bsz=${TEST_BSZ} \ - --num_passes=1 \ - --use_tpu=False - - -echo 'Run evaluation on test set...' 
-python train_gpu.py \ - --data_dir=${DATA_DIR}/tfrecords \ - --record_info_dir=${DATA_DIR}/tfrecords/ \ - --corpus_info_path=${DATA_DIR}/corpus-info.json \ - --eval_ckpt_path=${TEST_CKPT_PATH} \ - --model_dir=EXP-wt103 \ - --div_val=${DIV_VAL} \ - --untie_r=True \ - --proj_share_all_but_first=True \ - --n_layer=${N_LAYER} \ - --d_model=${D_MODEL} \ - --d_embed=${D_EMBED} \ - --n_head=${N_HEAD} \ - --d_head=${D_HEAD} \ - --d_inner=${D_INNER} \ - --dropout=0.0 \ - --dropatt=0.0 \ - --tgt_len=${TEST_TGT_LEN} \ - --mem_len=${TEST_MEM_LEN} \ - --clamp_len=${TEST_CLAMP_LEN} \ - --same_length=True \ - --eval_batch_size=${TEST_BSZ} \ - --num_core_per_host=${TEST_NUM_CORE} \ - --do_train=False \ - --do_eval=True \ - --eval_split=test - diff --git a/transformer-xl/tf/tpu_estimator.py b/transformer-xl/tf/tpu_estimator.py deleted file mode 100644 index 7bc3598..0000000 --- a/transformer-xl/tf/tpu_estimator.py +++ /dev/null @@ -1,3519 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# =================================================================== -"""TPUEstimator class.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import copy -import os -import signal -import sys -import threading -import time -import numpy as np -import six -from six.moves import queue as Queue # pylint: disable=redefined-builtin -from six.moves import xrange # pylint: disable=redefined-builtin - -import math - -try: - import google3 - from google3.third_party.tensorflow.contrib.tpu.python.ops import tpu_ops - from google3.third_party.tensorflow.contrib.tpu.python.tpu import error_handling - from google3.third_party.tensorflow.contrib.tpu.python.tpu import session_support - from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu - from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_config - from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_context - from google3.third_party.tensorflow.contrib.tpu.python.tpu import tpu_feed - from google3.third_party.tensorflow.contrib.tpu.python.tpu import training_loop - from google3.third_party.tensorflow.contrib.tpu.python.tpu import util as util_lib - from google3.third_party.tensorflow.contrib.training.python.training import hparam - from google3.third_party.tensorflow.core.framework import variable_pb2 - from google3.third_party.tensorflow.core.framework.summary_pb2 import Summary - from google3.third_party.tensorflow.core.protobuf import config_pb2 - from google3.third_party.tensorflow.python.data.ops import dataset_ops - from google3.third_party.tensorflow.python.data.util import nest as data_nest - from google3.third_party.tensorflow.python.estimator import estimator as estimator_lib - from google3.third_party.tensorflow.python.estimator import model_fn as model_fn_lib - from google3.third_party.tensorflow.python.estimator.export import export_output as export_output_lib - 
from google3.third_party.tensorflow.python.framework import constant_op - from google3.third_party.tensorflow.python.framework import dtypes - from google3.third_party.tensorflow.python.framework import errors - from google3.third_party.tensorflow.python.framework import ops - from google3.third_party.tensorflow.python.ops import array_ops - from google3.third_party.tensorflow.python.ops import check_ops - from google3.third_party.tensorflow.python.ops import control_flow_ops - from google3.third_party.tensorflow.python.ops import init_ops - from google3.third_party.tensorflow.python.ops import math_ops - from google3.third_party.tensorflow.python.ops import resource_variable_ops - from google3.third_party.tensorflow.python.ops import state_ops - from google3.third_party.tensorflow.python.ops import summary_ops_v2 as contrib_summary - from google3.third_party.tensorflow.python.ops import variable_scope - from google3.third_party.tensorflow.python.ops import variables - from google3.third_party.tensorflow.python.platform import tf_logging as logging - from google3.third_party.tensorflow.python.saved_model import tag_constants - from google3.third_party.tensorflow.python.summary import summary - from google3.third_party.tensorflow.python.training import basic_session_run_hooks - from google3.third_party.tensorflow.python.training import evaluation - from google3.third_party.tensorflow.python.training import session_run_hook - from google3.third_party.tensorflow.python.training import training - from google3.third_party.tensorflow.python.training import training_util - from google3.third_party.tensorflow.python.util import function_utils - from google3.third_party.tensorflow.python.util import nest - from google3.third_party.tensorflow.python.util import tf_inspect -except: - import tensorflow - from tensorflow.contrib.tpu.python.ops import tpu_ops - from tensorflow.contrib.tpu.python.tpu import error_handling - from tensorflow.contrib.tpu.python.tpu import session_support - from tensorflow.contrib.tpu.python.tpu import tpu - from tensorflow.contrib.tpu.python.tpu import tpu_config - from tensorflow.contrib.tpu.python.tpu import tpu_context - from tensorflow.contrib.tpu.python.tpu import tpu_feed - from tensorflow.contrib.tpu.python.tpu import training_loop - from tensorflow.contrib.tpu.python.tpu import util as util_lib - from tensorflow.contrib.training.python.training import hparam - from tensorflow.core.framework import variable_pb2 - from tensorflow.core.framework.summary_pb2 import Summary - from tensorflow.core.protobuf import config_pb2 - from tensorflow.python.data.ops import dataset_ops - from tensorflow.python.data.util import nest as data_nest - from tensorflow.python.estimator import estimator as estimator_lib - from tensorflow.python.estimator import model_fn as model_fn_lib - from tensorflow.python.estimator import util as estimator_util - from tensorflow.python.estimator.export import export_output as export_output_lib - from tensorflow.python.framework import constant_op - from tensorflow.python.framework import dtypes - from tensorflow.python.framework import errors - from tensorflow.python.framework import ops - from tensorflow.python.ops import array_ops - from tensorflow.python.ops import check_ops - from tensorflow.python.ops import control_flow_ops - from tensorflow.python.ops import init_ops - from tensorflow.python.ops import math_ops - from tensorflow.python.ops import resource_variable_ops - from tensorflow.python.ops import state_ops - from tensorflow.python.ops 
import summary_ops_v2 as contrib_summary - from tensorflow.python.ops import variable_scope - from tensorflow.python.ops import variables - from tensorflow.python.platform import tf_logging as logging - from tensorflow.python.saved_model import tag_constants - from tensorflow.python.summary import summary - from tensorflow.python.training import basic_session_run_hooks - from tensorflow.python.training import evaluation - from tensorflow.python.training import session_run_hook - from tensorflow.python.training import training - from tensorflow.python.training import training_util - from tensorflow.python.util import function_utils - from tensorflow.python.util import nest - from tensorflow.python.util import tf_inspect - - -_INITIAL_LOSS = 1e7 -_ZERO_LOSS = 0. -_TPU_ESTIMATOR = 'custom_tpu_estimator' # CHANGE FOR RECURRENCY -_ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' -_BATCH_SIZE_KEY = 'batch_size' -_CTX_KEY = 'context' -_USE_TPU_KEY = 'use_tpu' -_CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' -_ONE_GIGABYTE = 1024 * 1024 * 1024 -_TPU_ENQUEUE_OPS = '_tpu_enqueue_ops' -_TPU_TRAIN_OP = '_tpu_train_op' -_REWRITE_FOR_INFERENCE_MODE = '_rewrite_for_inference' - -# Ideally _USE_TPU_KEY should be reserved as well. However there are already -# models that make use of this key, thus it can not be reserved now to prevent -# breakage. In the long run, we would like to mitigate this by migrating models -# off of using _USE_TPU_KEY. -_RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY, _CTX_KEY] - - -# TODO(b/65703635): Flip the value and remove all dead code. Currently, this is -# only used for per-core based deployments. For per-host based pipelines, if a -# user returns a Dataset instance it will be automatically wrapped in a -# tf.while_loop (This can be disabled by returning features and labels -# explicitly). -_WRAP_INPUT_FN_INTO_WHILE_LOOP = False - - -ops.register_proto_function( - '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR), - proto_type=variable_pb2.VariableDef, - to_proto=resource_variable_ops._to_proto_fn, # pylint: disable=protected-access - from_proto=resource_variable_ops._from_proto_fn) # pylint: disable=protected-access - - -def _create_global_step(graph): - graph = graph or ops.get_default_graph() - if training.get_global_step(graph) is not None: - raise ValueError('"global_step" already exists.') - # Create in proper graph and base name_scope. - with graph.as_default() as g, g.name_scope(None): - return variable_scope.get_variable( - ops.GraphKeys.GLOBAL_STEP, - shape=[], - dtype=dtypes.int64, - initializer=init_ops.zeros_initializer(), - trainable=False, - use_resource=True, - collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP]) - - -def _create_or_get_iterations_per_loop(): - """Creates or gets the iterations_per_loop variable. - - In TPUEstimator, the user provided computation, the model_fn, is wrapped - inside a tf.while_loop for peak performance. The iterations of the loop are - specified by this variable, which adjusts its value on the CPU after each TPU - program execution and before the next TPU execution. - - The purpose of using a variable, rather then a constant, is to allow - TPUEstimator adapt the TPU training iterations according to the final steps - specified by users. For example, if the user sets the iterations_per_loop as 4 - in TPUConfig and steps as 10 in TPUEstimator.train(), the iterations_per_loop - variable will have the following value before each TPU training. 
- - - 1-th TPU execution: iterations_per_loop = 4 - - 2-th TPU execution: iterations_per_loop = 4 - - 3-th TPU execution: iterations_per_loop = 2 - - As model_fn increases the global step once per train_op invocation, the global - step is 10 after all TPU executions, matching the steps=10 inputs passed in by - users. - - Returns: - A TF non-trainable resource variable. - - Raises: - RuntimeError: If multi iterations_per_loop variables were found. - """ - graph = ops.get_default_graph() - collection_name = '{}_{}'.format(_TPU_ESTIMATOR, _ITERATIONS_PER_LOOP_VAR) - iter_vars = graph.get_collection(collection_name) - if len(iter_vars) == 1: - return iter_vars[0] - elif len(iter_vars) > 1: - raise RuntimeError('Multiple iterations_per_loop_var in collection.') - - with ops.colocate_with(training_util.get_global_step()): - with variable_scope.variable_scope( - _TPU_ESTIMATOR, reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[collection_name, ops.GraphKeys.LOCAL_VARIABLES], - use_resource=True) - - -def _sync_variables_ops(): - # Gets the variables back from TPU nodes. This means the variables updated - # by TPU will now be *synced* to host memory. - return [ - array_ops.check_numerics(v.read_value(), - 'Gradient for %s is NaN' % v.name).op - for v in variables.trainable_variables() - ] - - -def _increase_eval_step_op(iterations_per_loop): - """Returns an op to increase the eval step for TPU evaluation. - - Args: - iterations_per_loop: Tensor. The number of eval steps running in TPU - system before returning to CPU host for each `Session.run`. - - Returns: - An operation - """ - eval_step = evaluation._get_or_create_eval_step() # pylint: disable=protected-access - # Estimator evaluate increases 1 by default. So, we increase the difference. - return state_ops.assign_add( - eval_step, - math_ops.cast(iterations_per_loop - 1, dtype=eval_step.dtype), - use_locking=True) - - -def _extract_key_names(tensor_or_dict): - if isinstance(tensor_or_dict, dict): - return sorted(tensor_or_dict.keys()) - return [] - - -class _SIGNAL(object): - """Signal used to control the thread of infeed/outfeed. - - All preserved signals must be negative numbers. Positive numbers are used to - indicate the number of iterations for next training/evaluation loop. - """ - NEXT_BATCH = -1 - STOP = -2 - - -class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - """Ops and objects returned from a `model_fn` and passed to `TPUEstimator`. - - See `EstimatorSpec` for `mode`, `predictions`, `loss`, `train_op`, and - `export_outputs`. - - For evaluation, `eval_metrics `is a tuple of `metric_fn` and `tensors`, where - `metric_fn` runs on CPU to generate metrics and `tensors` represents the - `Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`. - To be precise, TPU evaluation expects a slightly different signature from the - @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a - dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`. - The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The - `tensors` usually specify the model logits, which are transferred back from - TPU system to CPU host. All tensors must have be batch-major, i.e., the batch - size is the first dimension. 
Once all tensors are available at CPU host from - all shards, they are concatenated (on CPU) and passed as positional arguments - to the `metric_fn` if `tensors` is list or keyword arguments if `tensors` is - a dict. `metric_fn` takes the `tensors` and returns a dict from metric string - name to the result of calling a metric function, namely a `(metric_tensor, - update_op)` tuple. See `TPUEstimator` for MNIST example how to specify the - `eval_metrics`. - - `scaffold_fn` is a function running on CPU to generate the `Scaffold`. This - function should not capture any Tensors in `model_fn`. - - `host_call` is a tuple of a `function` and a list or dictionary of `tensors` - to pass to that function and returns a list of Tensors. `host_call` currently - works for train() and evaluate(). The Tensors returned by the function is - executed on the CPU on every step, so there is communication overhead when - sending tensors from TPU to CPU. To reduce the overhead, try reducing the - size of the tensors. The `tensors` are concatenated along their major (batch) - dimension, and so must be >= rank 1. The `host_call` is useful for writing - summaries with @{tf.contrib.summary.create_file_writer}. - """ - - def __new__(cls, - mode, - predictions=None, - loss=None, - train_op=None, - eval_metrics=None, - export_outputs=None, - scaffold_fn=None, - host_call=None, - training_hooks=None, - evaluation_hooks=None, - prediction_hooks=None): - """Creates a validated `TPUEstimatorSpec` instance.""" - host_calls = {} - if eval_metrics is not None: - host_calls['eval_metrics'] = eval_metrics - if host_call is not None: - host_calls['host_call'] = host_call - _OutfeedHostCall.validate(host_calls) - - training_hooks = list(training_hooks or []) - evaluation_hooks = list(evaluation_hooks or []) - prediction_hooks = list(prediction_hooks or []) - - for hook in training_hooks + evaluation_hooks + prediction_hooks: - if not isinstance(hook, session_run_hook.SessionRunHook): - raise TypeError( - 'All hooks must be SessionRunHook instances, given: {}'.format( - hook)) - - return super(TPUEstimatorSpec, cls).__new__( - cls, - mode=mode, - predictions=predictions, - loss=loss, - train_op=train_op, - eval_metrics=eval_metrics, - export_outputs=export_outputs, - scaffold_fn=scaffold_fn, - host_call=host_call, - training_hooks=training_hooks, - evaluation_hooks=evaluation_hooks, - prediction_hooks=prediction_hooks) - - def as_estimator_spec(self): - """Creates an equivalent `EstimatorSpec` used by CPU train/eval.""" - host_calls = {} - if self.eval_metrics is not None: - host_calls['eval_metrics'] = self.eval_metrics - if self.host_call is not None: - host_calls['host_call'] = self.host_call - host_call_ret = _OutfeedHostCall.create_cpu_hostcall(host_calls) - eval_metric_ops = None - if self.eval_metrics is not None: - eval_metric_ops = host_call_ret['eval_metrics'] - hooks = None - if self.host_call is not None: - hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])] - hooks = list(hooks or []) - scaffold = self.scaffold_fn() if self.scaffold_fn else None - return model_fn_lib.EstimatorSpec( - mode=self.mode, - predictions=self.predictions, - loss=self.loss, - train_op=self.train_op, - eval_metric_ops=eval_metric_ops, - export_outputs=self.export_outputs, - scaffold=scaffold, - training_hooks=self.training_hooks + hooks, - evaluation_hooks=self.evaluation_hooks + hooks, - prediction_hooks=self.prediction_hooks + hooks) - - -class _OpQueueContext(object): - """Manages work queue and thread for a infeed/outfeed thread.""" - 
- def __init__(self, name, target, args): - self._name = name - self._queue = Queue.Queue() - args = (self,) + args - self._thread = threading.Thread(name=name, target=target, args=args) - self._thread.daemon = True - self._thread.start() - - def stop(self): - self._queue.put(_SIGNAL.STOP) - - def send_next_batch_signal(self, iterations): - self._queue.put(iterations) - - def read_iteration_counts(self): - while True: - iterations = self._queue.get(block=True) - logging.debug('%s read iterations %s', self._name, iterations) - if iterations == _SIGNAL.STOP: - logging.info('%s received shutdown signal, stopping.', self._name) - return - yield iterations - - def join(self): - logging.info('Shutting down %s thread.' % self._name) - self.stop() - self._thread.join() - - -class _OpSignalOnceQueueContext(_OpQueueContext): - """Manages work queue and thread for a infeed/outfeed thread. - - This subclass only signals once. - """ - - def __init__(self, name, target, args): - super(_OpSignalOnceQueueContext, self).__init__(name, target, args) - self._has_signaled = False - - def send_next_batch_signal(self, iterations): - if not self._has_signaled: - self._queue.put(iterations) - self._has_signaled = True - - -class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): - """A Session hook setting up the TPU initialization, infeed, and outfeed. - - This hook does two major things: - 1. initialize and shutdown TPU system. - 2. launch and join the threads for infeed enqueue and (optional) outfeed - dequeue. - """ - - def __init__(self, - ctx, - enqueue_ops, - dequeue_ops, - run_infeed_loop_on_coordinator=True, - rendezvous=None): - self._master_job = ctx.master_job - self._enqueue_ops = enqueue_ops - self._dequeue_ops = dequeue_ops - self._rendezvous = rendezvous - - self._run_infeed_loop_on_coordinator = run_infeed_loop_on_coordinator - self._initial_infeed_sleep_secs = ( - ctx.config.tpu_config.initial_infeed_sleep_secs) - - self._feed_error = None - self._finished = False - - def begin(self): - logging.info('TPU job name %s', self._master_job) - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - self._init_ops = [tpu.initialize_system(job=self._master_job)] - self._finalize_ops = [tpu.shutdown_system(job=self._master_job)] - - summary_writer_init_ops = contrib_summary.summary_writer_initializer_op() - self._init_ops.extend(summary_writer_init_ops) - # Get all the writer resources from the initializer, so we know what to - # flush. 
- for op in summary_writer_init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def _run_infeed(self, queue_ctx, session): - logging.info('Starting infeed thread controller.') - if self._initial_infeed_sleep_secs: - logging.info('%s thread sleeping for %d seconds.', self._name, - self._initial_infeed_sleep_secs) - time.sleep(self._initial_infeed_sleep_secs) - logging.info('%s thread starting after sleep', self._name) - - with self._rendezvous.catch_errors(source='infeed', session=session): - if self._run_infeed_loop_on_coordinator: - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(self._enqueue_ops) - else: - for _ in queue_ctx.read_iteration_counts(): - session.run(self._enqueue_ops) - logging.info('Infeed thread finished, shutting down.') - - def _run_outfeed(self, queue_ctx, session): - logging.info('Starting outfeed thread controller.') - with self._rendezvous.catch_errors(source='outfeed', session=session): - for count, steps in enumerate(queue_ctx.read_iteration_counts()): - for i in xrange(steps): - logging.debug('Outfeed dequeue for iteration (%d, %d)', count, i) - session.run(self._dequeue_ops) - logging.info('Outfeed thread finished, shutting down.') - - def _create_infeed_controller(self, name, target, args): - return _OpQueueContext(name=name, target=target, args=args) - - def after_create_session(self, session, coord): - logging.info('Init TPU system') - session.run(self._init_ops, - options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000)) - - self._infeed_controller = self._create_infeed_controller( - name='InfeedController', target=self._run_infeed, args=(session,)) - - self._outfeed_controller = _OpQueueContext( - name='OutfeedController', target=self._run_outfeed, args=(session,)) - - def before_run(self, run_context): - self._feed_error = None - - iterations = run_context.session.run(self._iterations_per_loop_var) - - logging.info('Enqueue next (%d) batch(es) of data to infeed.', iterations) - self._infeed_controller.send_next_batch_signal(iterations) - - logging.info('Dequeue next (%d) batch(es) of data from outfeed.', - iterations) - self._outfeed_controller.send_next_batch_signal(iterations) - - def end(self, session): - self._finished = True - logging.info('Stop infeed thread controller') - self._infeed_controller.join() - self._rendezvous.record_done('infeed') - - logging.info('Stop output thread controller') - self._outfeed_controller.join() - self._rendezvous.record_done('outfeed') - - logging.info('Shutdown TPU system.') - session.run(self._finalize_ops) - - -class TPUInfeedOutfeedSessionHookForPrediction(TPUInfeedOutfeedSessionHook): - - def __init__(self, ctx, enqueue_ops, dequeue_ops, rendezvous=None): - super(TPUInfeedOutfeedSessionHookForPrediction, self).__init__( - ctx, enqueue_ops, dequeue_ops, run_infeed_loop_on_coordinator=False, - rendezvous=rendezvous) - - def _create_infeed_controller(self, name, target, args): - return _OpSignalOnceQueueContext(name=name, target=target, args=args) - - -class _TPUStopAtStepHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step. - - This hook is similar to the `session_run_hook._StopAfterNEvalsHook` with - following differences for TPU training: - - 1. This hook sets the variable for iterations_per_loop, which is used by - `TPUInfeedOutfeedSessionHook` to control the iterations for infeed/outfeed. 
- As the hook execution order is not guaranteed, the variable update is - handled in `after_create_session` and `after_run` as - `TPUInfeedOutfeedSessionHook` reads the variable value in `before_run`. - - 2. For each training loop (session.run), the global step could be increased - multiple times on TPU. The global step tensor value will be explicitly read - again in `after_run` to ensure the latest value is retrieved to avoid race - condition. - """ - - def __init__(self, iterations, num_steps=None, last_step=None): - """Initializes a `StopAtStepHook`. - - Args: - iterations: The number of iterations to run optimizer per training loop. - num_steps: Number of steps to execute. - last_step: Step after which to stop. - - Raises: - ValueError: If one of the arguments is invalid. - """ - if num_steps is None and last_step is None: - raise ValueError('One of num_steps or last_step must be specified.') - if num_steps is not None and last_step is not None: - raise ValueError('Only one of num_steps or last_step can be specified.') - self._num_steps = num_steps - self._last_step = last_step - self._iterations = iterations - - def _next_iterations(self, global_step, last_step): - gap = last_step - global_step - return min(gap, self._iterations) - - def begin(self): - self._global_step_tensor = training_util.get_global_step() - if self._global_step_tensor is None: - raise RuntimeError('Global step should be created.') - - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - global_step = session.run(self._global_step_tensor) - if self._last_step is None: - self._last_step = global_step + self._num_steps - - iterations = self._next_iterations(global_step, self._last_step) - - self._iterations_per_loop_var.load(iterations, session=session) - - def after_run(self, run_context, run_values): - # Global step cannot be retrieved via SessionRunArgs and before_run due to - # race condition. - global_step = run_context.session.run(self._global_step_tensor) - if global_step >= self._last_step: - run_context.request_stop() - else: - iterations = self._next_iterations(global_step, self._last_step) - self._iterations_per_loop_var.load( - iterations, session=run_context.session) - - -class _SetEvalIterationsHook(session_run_hook.SessionRunHook): - """Hook that requests stop at a specified step.""" - - def __init__(self, num_steps): - """Initializes a `_SetEvalIterationsHook`. - - Args: - num_steps: Number of steps to execute. - """ - self._num_steps = num_steps - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - self._iterations_per_loop_var.load(self._num_steps, session=session) - - -class _StoppingPredictHook(session_run_hook.SessionRunHook): - """Hook that requests stop according to the stopping signal in prediction.""" - - def __init__(self, scalar_stopping_signal): - self._scalar_stopping_signal = scalar_stopping_signal - - def begin(self): - self._iterations_per_loop_var = _create_or_get_iterations_per_loop() - - def after_create_session(self, session, coord): - # This is not necessary as we do not run infeed enqueue and outfeed dequeue - # in side threads for prediction model. But it makes the - # TPUInfeedOutfeedSessionHook prints nice message. 
- self._iterations_per_loop_var.load(1, session=session) - - def before_run(self, run_context): - return session_run_hook.SessionRunArgs(self._scalar_stopping_signal) - - def after_run(self, run_context, run_values): - _ = run_context - scalar_stopping_signal = run_values.results - if _StopSignals.should_stop(scalar_stopping_signal): - # NOTE(xiejw): In prediction, stopping signals are inserted for each - # batch. And we append one more batch to signal the system it should stop. - # The data flow might look like - # - # batch 0: images, labels, stop = 0 (user provided) - # batch 1: images, labels, stop = 0 (user provided) - # ... - # batch 99: images, labels, stop = 0 (user provided) - # batch 100: images, labels, stop = 1 (TPUEstimator appended) - # - # where the final batch (id = 100) is appended by TPUEstimator, so we - # should drop it before returning the predictions to user. - # To achieve that, we throw the OutOfRangeError in after_run. Once - # Monitored Session sees this error in SessionRunHook.after_run, the - # "current" prediction, i.e., batch with id=100, will be discarded - # immediately - raise errors.OutOfRangeError(None, None, 'Stopped by stopping signal.') - - -def generate_per_core_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, host_device, host_id): - """Generates infeed enqueue ops for per-core input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A fn returns enqueue_ops.""" - num_cores_per_host = ctx.num_of_cores_per_host - per_host_sharded_inputs = [] - for core_ordinal in range(num_cores_per_host): - with ops.name_scope('ordinal_%d' % (core_ordinal)): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, - input_device=host_device, - invocation_index=host_id * ctx.num_of_cores_per_host + core_ordinal - ) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - if inputs.is_dataset: - raise TypeError( - '`input_fn` returning `Dataset` is not yet supported in ' - 'per-Core input pipeline deployment yet. 
Please set ' - 'TPUConfig.per_host_input_for_training to True or return ' - '`features` and `labels` from `input_fn`') - features, labels = inputs.features_and_labels() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels)) - per_host_sharded_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, tpu_ordinal_function=tpu_ordinal_function_impl) - return per_host_enqueue_ops - - return enqueue_ops_fn, captured_infeed_queue - - -def generate_per_host_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, batch_axis, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - - hooks = [] - - with ops.DEVICE(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, - input_device=device, - invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - if batch_axis is not None: - raise TypeError('For mode PREDICT, batch_axis is not supported yet.') - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - hooks.append(inputs.dataset_initializer_hook()) - - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """A Fn returning the TPU infeed enqueue ops. - - By providing as a Fn, it can be invoked inside the tf.while_loop such that - the input pipeline for multiple iterations can be executed by one - Session.run call. - - Returns: - list of dict of ops. - """ - with ops.DEVICE(device): - num_of_replicas_per_host = ctx.num_of_replicas_per_host - # Convert user input to features and labels. 
If the user returns a - # dataset, it is initialized and the features and labels extracted via - # `dataset.iterator.get_next()` - features, labels = inputs.features_and_labels() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure(features, labels) - unsharded_tensor_list = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_tensor_list], - tuple_shapes=[t.shape for t in unsharded_tensor_list], - shard_dimensions=batch_axis) - captured_infeed_queue.capture(infeed_queue) - infeed_queue.set_number_of_shards(num_of_replicas_per_host) - per_host_enqueue_ops = ( - infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_tensor_list, - placement_function=lambda x: device, - tpu_ordinal_function=tpu_ordinal_function_impl)) - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset - - -def generate_per_host_v2_enqueue_ops_fn_for_host( - ctx, input_fn, inputs_structure_recorder, device, host_id): - """Generates infeed enqueue ops for per-host input_fn on a single host.""" - captured_infeed_queue = _CapturedObject() - hooks = [] - - with ops.DEVICE(device): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, - input_device=device, - invocation_index=host_id) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if not is_dataset: - raise TypeError('`input_fn` must return a `Dataset` for the PER_HOST_V2 ' - 'input pipeline configuration.') - - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True, - num_invocations_per_step=ctx.num_of_replicas_per_host) - - hooks.append(inputs.dataset_initializer_hook()) - tpu_ordinal_function_impl = ctx.tpu_ordinal_function(host_id) - - def enqueue_ops_fn(): - """Generates the per_host enqueue ops.""" - control_deps = [] - per_host_sharded_inputs = [] - num_replicas_per_host = ctx.num_of_replicas_per_host - cached_signals = None - with ops.DEVICE(device): - if not inputs.is_dataset: - raise TypeError('`input_fn` must return a `Dataset` for this mode.') - for _ in range(num_replicas_per_host): - # Use control dependencies to ensure a deterministic ordering. - with ops.control_dependencies(control_deps): - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - # All the replicas share the replica 0's stopping singal. - # This avoids inconsistent state among different model replcias. 
- if cached_signals: - signals['stopping'] = cached_signals['stopping'] - else: - cached_signals = signals - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - control_deps.extend(flattened_inputs) - per_host_sharded_inputs.append(flattened_inputs) - - if inputs_structure_recorder.flattened_input_dims: - input_partition_dims = inputs_structure_recorder.flattened_input_dims - if signals: - input_partition_dims += [None] * len(signals) - # pylint: disable=protected-access - infeed_queue = tpu_feed._PartitionedInfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0]), - host_id=host_id, - input_partition_dims=input_partition_dims, - device_assignment=ctx.device_assignment) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs) - else: - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(per_host_sharded_inputs[0])) - per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( - per_host_sharded_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl) - captured_infeed_queue.capture(infeed_queue) - - if signals is None: - return per_host_enqueue_ops - else: - return { - 'ops': per_host_enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset - - -def generate_broadcast_enqueue_ops_fn(ctx, input_fn, inputs_structure_recorder, - num_hosts): - """Generates infeed enqueue ops for one input_fn on all the hosts.""" - captured_infeed_queue = _CapturedObject() - hooks = [] - device_0 = ctx.tpu_host_placement_function(host_id=0) - with ops.DEVICE(device_0): - user_context = tpu_context.TPUContext( - internal_ctx=ctx, input_device=device_0, invocation_index=0) - inputs = _Inputs.from_input_fn(input_fn(user_context)) - - is_dataset = inputs.is_dataset - if ctx.mode == model_fn_lib.ModeKeys.PREDICT: - if not is_dataset: - raise TypeError( - 'For mode PREDICT, `input_fn` must return `Dataset` instead of ' - '`features` and `labels`.') - - inputs = _InputsWithStoppingSignals( - dataset=inputs.dataset, - batch_size=ctx.batch_size_for_input_fn, - add_padding=True) - - if is_dataset: - hooks.append(inputs.dataset_initializer_hook()) - num_replicas_per_host = ctx.num_of_replicas_per_host - - def tpu_ordinal_function_impl(replica_id): - if ctx.device_assignment: - return ctx.device_assignment.tpu_ordinal(replica=replica_id) - else: - return replica_id % num_replicas_per_host - - def device_function_impl(replica_id): - return ctx.tpu_host_placement_function(replica_id=replica_id) - - def enqueue_ops_fn(): - """Generates enqueue ops for all the hosts.""" - broadcasted_inputs = [] - flattened_inputs = None # Cache result from input_fn. - signals = None - for host_id in xrange(num_hosts): - with ops.DEVICE(ctx.tpu_host_placement_function(host_id=host_id)): - for _ in xrange(ctx.num_of_replicas_per_host): - # Note: input_fn is only called once at host 0 for the first replica. - # The features and labels returned from that invocation are - # broadcasted to other replicas(including the replicas on other - # hosts). 
- if flattened_inputs is None: - features, labels = inputs.features_and_labels() # Calls get_next() - signals = inputs.signals() - - inputs_structure_recorder.validate_and_record_structure( - features, labels) - flattened_inputs = ( - inputs_structure_recorder.flatten_features_and_labels( - features, labels, signals)) - broadcasted_inputs.append(flattened_inputs) - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(broadcasted_inputs[0])) - captured_infeed_queue.capture(infeed_queue) - enqueue_ops = infeed_queue.generate_enqueue_ops( - broadcasted_inputs, - tpu_ordinal_function=tpu_ordinal_function_impl, - placement_function=device_function_impl) - - if signals is None: - return enqueue_ops - else: - return { - 'ops': enqueue_ops, - 'signals': signals, - } - - return enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset - - -class _InputPipeline(object): - """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. - - `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from - call site. To be precise, based on the configuration in - `_InternalTPUContext`, it invokes `input_fn` for all cores (usually - multi-host TPU training) or for one host (usually for single-host TPU - evaluation), and sends all `features` and `labels` returned by `input_fn` to - TPU infeed. For per-core invocation, `features` and `labels` are piped to - infeed directly, one tuple for each core. For per-host invocation, `features` - and `labels` are split at host (with respect to `batch_axis`) and piped to all - cores accordingly. - - In addition, flatten/unflatten are handled by `_InputPipeline` also. Model - inputs returned by the `input_fn` can have one of the following forms: - 1. features - 2. (features, labels) - 3. ((arbitrarily nested structure of features), labels) - - Internally, form 1 is reformed to `(features, None)` as features and labels - are passed separately to underlying methods. For TPU training, TPUEstimator - may expect multiple `features` and `labels` tuples one for each core. - - TPUEstimator allows various different structures for inputs (namely `features` - and `labels`). `features` can be `Tensor`, dict of string name to `Tensor`, - or nested tuples and `labels` could be `None`, `Tensor`, or dict of string - name to `Tensor`. TPU infeed/outfeed library expects flattened tensor list. - So, `features` and `labels` need to be flattened, before infeed enqueue, and - the structure of them needs to be recorded, in order to restore them after - infeed dequeue. - """ - - class InputsStructureRecorder(object): - """The recorder to record inputs structure.""" - - def __init__(self, input_partition_dims=None): - # Holds the structure of inputs - self._feature_structure = {} - self._flattened_input_dims = None - - if input_partition_dims: - # This should have been validated in TPUConfig. - assert len(input_partition_dims) <= 2, 'must have 1 or 2 elements.' - if len(input_partition_dims) == 2: - self._feature_dims, self._label_dims = input_partition_dims - else: - self._feature_dims = input_partition_dims[0] - self._label_dims = None - - assert self._feature_dims is not None, ('input_partition_dims[0] must ' - 'not be None') - else: - self._feature_dims = None - self._label_dims = None - - # Internal state. - self._initialized = False - - @property - def flattened_input_dims(self): - assert self._initialized, 'InputsStructureRecorder is not initialized.' 
- return self._flattened_input_dims - - def has_labels(self): - return 'labels' in self._feature_structure - - def _flatten_input_dims(self, feature_dims, feature_dims_names, label_dims, - label_dims_names, label_names, has_labels): - """Flatten input dims with the same order as flattened input tensors.""" - flattened_input_dims = [] - if feature_dims_names: - # We need a fixed ordering for matching the tensors in features. - flattened_input_dims.extend( - [feature_dims[name] for name in feature_dims_names]) - else: - flattened_input_dims.append(feature_dims) - - if label_dims_names: - # We need a fixed ordering for matching the tensors in labels. - flattened_input_dims.extend( - [label_dims[name] for name in label_dims_names]) - else: - if label_names: - num_tensors_in_label = len(label_names) - else: - num_tensors_in_label = int(has_labels) - # Setting `None` in input_partition_dims[1] will apply `None` to - # all the tensors in labels, regardless of internal structure. - flattened_input_dims.extend([label_dims] * num_tensors_in_label) - - return flattened_input_dims - - def validate_and_record_structure(self, features, labels): - """Validates and records the structure of `features` and `labels`.""" - # Extract structure. - has_labels = labels is not None - feature_names = _extract_key_names(features) - label_names = _extract_key_names(labels) - - if not self._initialized: - # Record structure. - self._initialized = True - if self._feature_dims is not None: - feature_dims_names = _extract_key_names(self._feature_dims) - if feature_dims_names != feature_names: - raise ValueError( - 'TPUConfig.input_partition_dims[0] mismatched feature' - ' keys. Expected {}, got {}'.format(feature_names, - feature_dims_names)) - - label_dims_names = _extract_key_names(self._label_dims) - if self._label_dims is not None and label_dims_names != label_names: - raise ValueError( - 'TPUConfig.input_partition_dims[1] mismatched label' - ' keys. Expected {}, got {}'.format(label_names, - label_dims_names)) - - self._flattened_input_dims = self._flatten_input_dims( - self._feature_dims, feature_dims_names, self._label_dims, - label_dims_names, label_names, has_labels) - - def flatten_features_and_labels(self, features, labels, signals=None): - """Flattens the `features` and `labels` to a single tensor list.""" - self._feature_structure['features'] = features - if labels is not None: - self._feature_structure['labels'] = labels - if signals is not None: - self._feature_structure['signals'] = signals - return data_nest.flatten(self._feature_structure) - - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. - - Args: - flattened_inputs: Flattened inputs for each shard. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. - - Raises: - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. - """ - - unflattened_inputs = data_nest.pack_sequence_as(self._feature_structure, - flattened_inputs) - return _Inputs( - unflattened_inputs['features'], - unflattened_inputs.get('labels'), - signals=unflattened_inputs.get('signals')) - - def __init__(self, input_fn, batch_axis, ctx): - """Constructor. - - Args: - input_fn: input fn for train or eval. 
- batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. - ctx: A `_InternalTPUContext` instance with mode. - - Raises: - ValueError: If both `sharded_features` and `num_cores` are `None`. - """ - self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder( - ctx.input_partition_dims) - - self._sharded_per_core = ctx.is_input_sharded_per_core() - self._input_fn = input_fn - self._infeed_queue = None - self._ctx = ctx - self._batch_axis = batch_axis - - def generate_infeed_enqueue_ops_and_dequeue_fn(self): - """Generates infeed enqueue ops and dequeue_fn.""" - # While tf.while_loop is called, the body function, which invokes - # `enqueue_fn` passed in, is called to construct the graph. So, input_fn - # structure is recorded. - enqueue_ops, all_hooks, run_infeed_loop_on_coordinator = ( - self._invoke_input_fn_and_record_structure()) - - self._validate_input_pipeline() - - def dequeue_fn(): - """dequeue_fn is used by TPU to retrieve the tensors.""" - # In the model-parallel case, both the host-side and DEVICE-side - # computations must agree on the core on which infeed takes place. We - # choose to perform infeed on logical core 0 of each replica. - values = self._infeed_queue.generate_dequeue_op(tpu_device=0) - # The unflatten process uses the structure information recorded above. - return self._inputs_structure_recorder.unflatten_features_and_labels( - values) - - return (enqueue_ops, dequeue_fn, all_hooks, run_infeed_loop_on_coordinator) - - def _invoke_input_fn_and_record_structure(self): - """Deploys the input pipeline and record input structure.""" - enqueue_ops = [] - infeed_queues = [] - all_hooks = [] - num_hosts = self._ctx.num_hosts - tpu_host_placement_fn = self._ctx.tpu_host_placement_function - - run_infeed_loop_on_coordinator = True - - if self._sharded_per_core: - # Per-Core input pipeline deployment. - # Invoke input pipeline for each core and placed on the corresponding - # host. - for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.DEVICE(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - enqueue_ops_fn, captured_infeed_queue = ( - generate_per_core_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, self._inputs_structure_recorder, - host_device, host_id)) - - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - run_infeed_loop_on_coordinator = False - enqueue_ops.append( - _wrap_computation_in_while_loop( - device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - # Infeed_queue_getter must be called after enqueue_ops_fn is called. - infeed_queues.append(captured_infeed_queue.get()) - - elif self._ctx.is_input_broadcast_with_iterators(): - # Only calls input_fn in host 0. 
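Editor's aside: `dequeue_fn` above restores the flat tensor list using the structure recorded during enqueue. A pure-Python sketch of that flatten-on-enqueue / unflatten-on-dequeue round trip, with a toy recorder that only remembers dict keys (no TPU or TensorFlow involved):

```
# Illustrative sketch only; not the TensorFlow InputsStructureRecorder.
class StructureRecorder(object):
    """Remembers dict keys so a flat value list can be restored later."""

    def __init__(self):
        self._keys = None

    def flatten(self, features):
        self._keys = sorted(features)          # record the structure once
        return [features[k] for k in self._keys]

    def unflatten(self, flat_values):
        return dict(zip(self._keys, flat_values))

recorder = StructureRecorder()
flat = recorder.flatten({'tokens': [1, 2, 3], 'length': 3})   # enqueue side
restored = recorder.unflatten(flat)                           # dequeue side
assert restored == {'tokens': [1, 2, 3], 'length': 3}
```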
- host_device = tpu_host_placement_fn(host_id=0) - enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( - generate_broadcast_enqueue_ops_fn(self._ctx, self._input_fn, - self._inputs_structure_recorder, - num_hosts)) - all_hooks.extend(hooks) - if is_dataset: - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append(wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - else: - for host_id in range(num_hosts): - host_device = tpu_host_placement_fn(host_id=host_id) - with ops.DEVICE(host_device): - with ops.name_scope('input_pipeline_task%d' % (host_id)): - if self._ctx.is_input_per_host_with_iterators(): - enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( - generate_per_host_v2_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, host_device, host_id)) - else: - enqueue_ops_fn, captured_infeed_queue, hooks, is_dataset = ( - generate_per_host_enqueue_ops_fn_for_host( - self._ctx, self._input_fn, - self._inputs_structure_recorder, self._batch_axis, - host_device, host_id)) - all_hooks.extend(hooks) - - # NOTE(xiejw): We dispatch here based on the return type of the - # users `input_fn`. - # - # 1. If input_fn returns a Dataset instance, we initialize the - # iterator outside of tf.while_loop, and call the iterator.get_next - # inside tf.while_loop. This should be always safe. - # - # 2. If input_fn returns (features, labels), it is too late to wrap - # them inside tf.while_loop, as resource initialization cannot be - # handled in TF control flow properly. In this case, we will use - # python loop to enqueue the data into TPU system. This may be - # slow compared to the previous case. - if is_dataset: - run_infeed_loop_on_coordinator = False - wrap_fn = ( - _wrap_computation_in_while_loop - if self._ctx.mode != model_fn_lib.ModeKeys.PREDICT else - _wrap_computation_in_while_loop_with_stopping_signals) - enqueue_ops.append( - wrap_fn(device=host_device, op_fn=enqueue_ops_fn)) - else: - enqueue_ops.append(enqueue_ops_fn()) - infeed_queues.append(captured_infeed_queue.get()) - # infeed_queue is used to generate dequeue ops. The only thing it uses for - # dequeue is dtypes and types. So, any one can be used. Here, grab the - # first one. - self._infeed_queue = infeed_queues[0] - return enqueue_ops, all_hooks, run_infeed_loop_on_coordinator - - def _validate_input_pipeline(self): - """Validates the input pipeline. - - Perform some sanity checks to log user friendly information. We should - error out to give users better error message. But, if - _WRAP_INPUT_FN_INTO_WHILE_LOOP is False (legacy behavior), we cannot break - user code, so, log a warning. - - Raises: - RuntimeError: If the validation failed. - """ - if ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS): - err_msg = ('Input pipeline contains one or more QueueRunners. ' - 'It could be slow and not scalable. Please consider ' - 'converting your input pipeline to use `tf.data` instead (see ' - 'https://www.tensorflow.org/guide/datasets for ' - 'instructions.') - if _WRAP_INPUT_FN_INTO_WHILE_LOOP: - raise RuntimeError(err_msg) - else: - logging.warn(err_msg) - - -class _ModelFnWrapper(object): - """A `model_fn` wrapper. 
- - This makes calling model_fn on CPU and TPU easier and more consistent and - performs necessary check and mutation required by TPU training and evaluation. - - In addition, this wrapper manages converting the `model_fn` to a single TPU - train and eval step. - """ - - def __init__(self, model_fn, train_cache_fn, eval_cache_fn, config, params, ctx): - self._model_fn = model_fn - self._train_cache_fn = train_cache_fn - self._eval_cache_fn = eval_cache_fn - self._config = config - self._params = params - self._ctx = ctx - - def call_without_tpu(self, features, labels, is_export_mode): - return self._call_model_fn(features, labels, is_export_mode=is_export_mode) - - def convert_to_single_tpu_train_step(self, dequeue_fn): - """Converts user provided model_fn` as a single train step on TPU. - - The user provided `model_fn` takes input tuple - (features, labels) and produces the EstimatorSpec with train_op and loss for - train `mode`. This usually represents a single train computation on CPU. - - For TPU training, a train (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input should be taken from TPU infeed rather - than input pipeline (input_fn) directly. To fit TPU loop and replicate - pattern, the original train computation should be reformed, which is the - returned `train_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of train_fn, host_calls, and captured scaffold_fn. The train_fn - representing the train step for TPU. - """ - - host_call = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_training_hooks = _CapturedObject() - - def train_step(loss, *cache): - """Training step function for use inside a while loop.""" - if not self._params.get('track_mean', False): - del loss # unused; required in function signature. - - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - - # Consume the current cache - estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels, cache=cache)) - - # Retrieve the new returned cache - """ - `cache` consists of a list of tensors, potentially empty (of length 0) - """ - cache = estimator_spec.cache - new_loss, train_op = estimator_spec.loss, estimator_spec.train_op - - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - captured_scaffold_fn.capture(estimator_spec.scaffold_fn) - else: - captured_scaffold_fn.capture(None) - - captured_training_hooks.capture(estimator_spec.training_hooks) - - # We must run train_op to update the variables prior to running the - # outfeed. - with ops.control_dependencies([train_op]): - host_call_outfeed_ops = [] - if (isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec) # pylint: disable=protected-access - and estimator_spec.host_call is not None): - host_call.record({'host_call': estimator_spec.host_call}) - host_call_outfeed_ops = host_call.create_enqueue_op() - with ops.control_dependencies(host_call_outfeed_ops): - if self._params.get('track_mean', False): - loss = tensorflow.stop_gradient(loss) - return [math_ops.add(loss, new_loss)] + cache - else: - return [array_ops.identity(new_loss)] + cache - - return (train_step, host_call, captured_scaffold_fn, - captured_training_hooks) - - def convert_to_single_tpu_eval_step(self, dequeue_fn): - """Converts user provided model_fn` as a single eval step on TPU. 
- - Similar to training, the user provided `model_fn` takes input tuple - (features, labels) and produces the TPUEstimatorSpec with eval_metrics for - eval `mode`. This usually represents a single evaluation computation on CPU. - - For TPU evaluation, a eval (computation) step is first wrapped in a - tf.while_loop control flow to repeat for many times and then replicated to - all TPU shards. Besides the input and output are slightly different. Input, - features and labels, should be taken from TPU infeed rather than input - pipeline (input_fn) directly. Output is managed in two stages. First, the - model outputs as the result of evaluation computation, usually model logits, - should be transferred from TPU system to CPU. Then, all model outputs are - concatenated first on CPU and sent to the metric_fn for metrics computation. - To fit TPU evaluation pattern, the original eval computation should be - reformed, which is the returned `eval_step`. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of eval_fn, host_calls, and captured scaffold_fn. The eval_fn - representing the eval step for TPU. - """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_eval_hooks = _CapturedObject() - - def eval_step(total_loss, *cache): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - - # Consume the current cache - tpu_estimator_spec = self._call_model_fn(features, labels, cache=cache) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU evaluation must have type' - '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) - - # Retrieve the new returned cache - cache = tpu_estimator_spec.cache - loss = tpu_estimator_spec.loss - - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_eval_hooks.capture(tpu_estimator_spec.evaluation_hooks) - - to_record = {} - if tpu_estimator_spec.eval_metrics: - to_record['eval_metrics'] = tpu_estimator_spec.eval_metrics - if tpu_estimator_spec.host_call is not None: - # We assume that evaluate won't update global step, so we don't wrap - # this host_call. - to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return [math_ops.add(total_loss, loss)] + cache - - return eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - - def convert_to_single_tpu_predict_step(self, dequeue_fn): - """Converts user provided model_fn` as a single predict step on TPU. - - Args: - dequeue_fn: The function to retrieve inputs, features and labels, from TPU - infeed dequeue channel. - - Returns: - A tuple of predict_fn, host_calls, and captured scaffold_fn. The - predict_fn representing the predict step for TPU. 
- """ - host_calls = _OutfeedHostCall(self._ctx) - captured_scaffold_fn = _CapturedObject() - captured_predict_hooks = _CapturedObject() - - def predict_step(unused_scalar_stopping_signal): - """Evaluation step function for use inside a while loop.""" - inputs = dequeue_fn() - features, labels = inputs.features_and_labels() - stopping_signals = inputs.signals() - - assert stopping_signals is not None, ( - 'Internal Error: `signals` is missing.') - - tpu_estimator_spec = self._call_model_fn( - features, labels, is_export_mode=False) - if not isinstance(tpu_estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - raise RuntimeError( - 'estimator_spec used by TPU prediction must have type' - '`TPUEstimatorSpec`. Got {}'.format(type(tpu_estimator_spec))) - - self._verify_tpu_spec_predictions(tpu_estimator_spec.predictions) - - captured_scaffold_fn.capture(tpu_estimator_spec.scaffold_fn) - captured_predict_hooks.capture(tpu_estimator_spec.prediction_hooks) - to_record = {} - identity_fn = lambda **kwargs: kwargs - to_record['predictions'] = [identity_fn, tpu_estimator_spec.predictions] - to_record['signals'] = [identity_fn, stopping_signals] - if tpu_estimator_spec.host_call is not None: - to_record['host_call'] = tpu_estimator_spec.host_call - host_calls.record(to_record) - - with ops.control_dependencies(host_calls.create_enqueue_op()): - return _StopSignals.as_scalar_stopping_signal(stopping_signals) - - return (predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks) - - def _verify_tpu_spec_predictions(self, predictions): - """Validates TPUEstimatorSpec.predictions dict.""" - # TODO(xiejw): Adds validation for prediction dictionrary. - # TODO(xiejw): Adds support for single tensor as predictions. - if not isinstance(predictions, dict): - raise TypeError('TPUEstimatorSpec.predictions must be dict of Tensors.') - - for (key, tensor) in predictions.items(): - if tensor.shape[0].value is None: - raise ValueError( - 'The tensor with key ({}) in TPUEstimatorSpec.predictions has ' - 'dynamic shape (should be static). Tensor: {}'.format( - key, tensor)) - return predictions - - def _validate_model_features_and_labels(self, - features, - labels, - is_export_mode): - """Validates that the features and labels for the model function are valid. - - A valid features/labels object is the one with: - - Type: Tensor or a dictionary of Tensors - - Static shape if is_export_mode is False. - - Args: - features: the features that would be input to the model function. - labels: the labels that would be input to the model function. - is_export_mode: boolean value specifying if in export mode. - - Raises: - TypeError: If features/labels are not of the correct type. - ValueError: If features/labels have dynamic shape. - """ - - def validate(obj, obj_name): - """Helper validate function.""" - if not isinstance(obj, ops.Tensor) and not isinstance(obj, dict): - raise TypeError( - 'The {} to the model returned by input_fn must be either a Tensor ' - 'or a dictionary of Tensors. {}: {}'.format(obj_name, obj_name, - obj)) - if is_export_mode or self._ctx.is_running_on_cpu(is_export_mode): - return - if isinstance(obj, ops.Tensor): - if not obj.get_shape().is_fully_defined(): - raise ValueError( - 'The {} to the model returned by input_fn must have static shape.' 
- ' Tensor: {}'.format(obj_name, obj)) - else: - for (key, value) in obj.items(): - flattened_tensors = data_nest.flatten(value) - for tensor in flattened_tensors: - if not tensor.get_shape().is_fully_defined(): - raise ValueError( - 'The {} to the model returned by input_fn must have static ' - 'shape. Key: \'{}\', Tensor: {}'.format( - obj_name, key, tensor)) - - validate(features, 'features') - if labels is not None: - validate(labels, 'labels') - - def _call_model_fn(self, features, labels, cache=None, is_export_mode=False): - """Calls the model_fn with required parameters.""" - self._validate_model_features_and_labels(features, labels, is_export_mode) - model_fn_args = function_utils.fn_args(self._model_fn) - kwargs = {} - - # Makes deep copy with `config` and params` in case user mutates them. - config = copy.deepcopy(self._config) - params = copy.deepcopy(self._params) - - if 'labels' in model_fn_args: - kwargs['labels'] = labels - elif labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') - if 'mode' in model_fn_args: - kwargs['mode'] = self._ctx.mode - if 'config' in model_fn_args: - kwargs['config'] = config - if 'params' in model_fn_args: - kwargs['params'] = params - - if cache is not None: - params['cache'] = cache - - if 'params' not in model_fn_args: - raise ValueError('model_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params[\'batch_size\']'.format(self._model_fn)) - - if is_export_mode: - batch_size_for_model_fn = None - else: - batch_size_for_model_fn = self._ctx.batch_size_for_model_fn - - if batch_size_for_model_fn is not None: - _add_item_to_params(params, _BATCH_SIZE_KEY, batch_size_for_model_fn) - - running_on_cpu = self._ctx.is_running_on_cpu(is_export_mode) - _add_item_to_params(params, _USE_TPU_KEY, not running_on_cpu) - - if not running_on_cpu: - user_context = tpu_context.TPUContext( - internal_ctx=self._ctx, call_from_input_fn=False) - _add_item_to_params(params, _CTX_KEY, user_context) - - estimator_spec = self._model_fn(features=features, **kwargs) - if (running_on_cpu and - isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec)): # pylint: disable=protected-access - # The estimator_spec will be passed to `Estimator` directly, which expects - # type `EstimatorSpec`. - return estimator_spec.as_estimator_spec() - else: - return estimator_spec - - def _verify_estimator_spec(self, estimator_spec): - """Validates the estimator_spec.""" - if isinstance(estimator_spec, model_fn_lib._TPUEstimatorSpec): # pylint: disable=protected-access - return estimator_spec - - err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' - if estimator_spec.training_chief_hooks: - raise ValueError( - err_msg.format('training_chief_hooks') + 'If you want' + - ' to pass training hooks, please pass via training_hooks.') - - if estimator_spec.scaffold: - logging.warning('EstimatorSpec.Scaffold is ignored by TPU train/eval. ' - 'Please use TPUEstimatorSpec.') - return estimator_spec - - -class _OutfeedHostCall(object): - """Support for `eval_metrics` and `host_call` in TPUEstimatorSpec.""" - - def __init__(self, ctx): - self._ctx = ctx - self._names = [] - # All of these are dictionaries of lists keyed on the name. 
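Editor's aside: `_call_model_fn` above only passes `labels`, `mode`, `config`, and `params` when the user's `model_fn` declares them. A self-contained sketch of that signature-driven dispatch using `inspect`; the `tiny_model_fn` example is hypothetical:

```
# Illustrative sketch only; the real code uses function_utils.fn_args.
import inspect

def call_model_fn(model_fn, features, labels, mode, config, params):
    accepted = set(inspect.signature(model_fn).parameters)
    kwargs = {}
    if 'labels' in accepted:
        kwargs['labels'] = labels
    elif labels is not None:
        raise ValueError(
            'model_fn does not take labels, but input_fn returns labels.')
    for name, value in (('mode', mode), ('config', config), ('params', params)):
        if name in accepted:
            kwargs[name] = value
    return model_fn(features=features, **kwargs)

def tiny_model_fn(features, mode, params):
    return {'mode': mode, 'batch_size': params['batch_size'], 'n': len(features)}

spec = call_model_fn(tiny_model_fn, features=[1, 2], labels=None,
                     mode='train', config=None, params={'batch_size': 8})
assert spec == {'mode': 'train', 'batch_size': 8, 'n': 2}
```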
- self._host_fns = {} - self._tensor_keys = collections.defaultdict(list) - self._tensors = collections.defaultdict(list) - self._tensor_dtypes = collections.defaultdict(list) - self._tensor_shapes = collections.defaultdict(list) - - @staticmethod - def validate(host_calls): - """Validates the `eval_metrics` and `host_call` in `TPUEstimatorSpec`.""" - - for name, host_call in host_calls.items(): - if not isinstance(host_call, (tuple, list)): - raise ValueError('{} should be tuple or list'.format(name)) - if len(host_call) != 2: - raise ValueError('{} should have two elements.'.format(name)) - if not callable(host_call[0]): - raise TypeError('{}[0] should be callable.'.format(name)) - if not isinstance(host_call[1], (tuple, list, dict)): - raise ValueError('{}[1] should be tuple or list, or dict.'.format(name)) - - if isinstance(host_call[1], (tuple, list)): - fullargspec = tf_inspect.getfullargspec(host_call[0]) - fn_args = function_utils.fn_args(host_call[0]) - # wrapped_hostcall_with_global_step uses varargs, so we allow that. - if fullargspec.varargs is None and len(host_call[1]) != len(fn_args): - raise RuntimeError( - 'In TPUEstimatorSpec.{}, length of tensors {} does not match ' - 'method args of the function, which takes {}.'.format( - name, len(host_call[1]), len(fn_args))) - - @staticmethod - def create_cpu_hostcall(host_calls): - """Runs on the host_call on CPU instead of TPU when use_tpu=False.""" - - _OutfeedHostCall.validate(host_calls) - ret = {} - for name, host_call in host_calls.items(): - host_fn, tensors = host_call - if isinstance(tensors, (tuple, list)): - ret[name] = host_fn(*tensors) - else: - # Must be dict. - try: - ret[name] = host_fn(**tensors) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise e - return ret - - def record(self, host_calls): - """Records the host_call structure.""" - - for name, host_call in host_calls.items(): - host_fn, tensor_list_or_dict = host_call - self._names.append(name) - self._host_fns[name] = host_fn - - if isinstance(tensor_list_or_dict, dict): - for (key, tensor) in six.iteritems(tensor_list_or_dict): - self._tensor_keys[name].append(key) - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - else: - # List or tuple. - self._tensor_keys[name] = None - for tensor in tensor_list_or_dict: - self._tensors[name].append(tensor) - self._tensor_dtypes[name].append(tensor.dtype) - self._tensor_shapes[name].append(tensor.shape) - - def create_enqueue_op(self): - """Create the op to enqueue the recorded host_calls. - - Returns: - A list of enqueue ops, which is empty if there are no host calls. - """ - if not self._names: - return [] - - tensors = [] - # TODO(jhseu): Consider deduping tensors. - for name in self._names: - tensors.extend(self._tensors[name]) - - with ops.DEVICE(tpu.core(0)): - return [tpu_ops.outfeed_enqueue_tuple(tensors)] - - def create_tpu_hostcall(self): - """Sends the tensors through outfeed and runs the host_fn on CPU. - - The tensors are concatenated along dimension 0 to form a global tensor - across all shards. The concatenated function is passed to the host_fn and - executed on the first host. - - Returns: - A dictionary mapping name to the return type of the host_call by that - name. - - Raises: - RuntimeError: If outfeed tensor is scalar. 
- """ - if not self._names: - return {} - - ret = {} - # For each i, dequeue_ops[i] is a list containing the tensors from all - # shards. This list is concatenated later. - dequeue_ops = [] - tensor_dtypes = [] - tensor_shapes = [] - for name in self._names: - for _ in self._tensors[name]: - dequeue_ops.append([]) - for dtype in self._tensor_dtypes[name]: - tensor_dtypes.append(dtype) - for shape in self._tensor_shapes[name]: - tensor_shapes.append(shape) - - # Outfeed ops execute on each replica's first logical core. Note: we must - # constraint it such that we have at most one outfeed dequeue and enqueue - # per replica. - for i in xrange(self._ctx.num_replicas): - host_device, ordinal_id = self._ctx.device_for_replica(i) - with ops.DEVICE(host_device): - outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( - dtypes=tensor_dtypes, - shapes=tensor_shapes, - device_ordinal=ordinal_id) - for j, item in enumerate(outfeed_tensors): - dequeue_ops[j].append(item) - - # Deconstruct dequeue ops. - dequeue_ops_by_name = {} - pos = 0 - for name in self._names: - dequeue_ops_by_name[name] = dequeue_ops[pos:pos+len(self._tensors[name])] - pos += len(self._tensors[name]) - - # It is assumed evaluation always happens on single host TPU system. So, - # place all ops on tpu host if possible. - # - # TODO(jhseu): Evaluate whether this is right for summaries. - with ops.DEVICE(self._ctx.tpu_host_placement_function(replica_id=0)): - for name in self._names: - dequeue_ops = dequeue_ops_by_name[name] - for i, item in enumerate(dequeue_ops): - if dequeue_ops[i][0].shape.ndims == 0: - raise RuntimeError( - 'All tensors outfed from TPU should preserve batch size ' - 'dimension, but got scalar {}'.format(dequeue_ops[i][0])) - # TODO(xiejw): Allow users to specify the axis for batch size - # dimension. - dequeue_ops[i] = array_ops.concat(dequeue_ops[i], axis=0) - - if self._tensor_keys[name] is not None: - # The user-provided eval_metrics[1] is a dict. - dequeue_ops = dict(zip(self._tensor_keys[name], dequeue_ops)) - try: - ret[name] = self._host_fns[name](**dequeue_ops) - except TypeError as e: - logging.warning( - 'Exception while calling %s: %s. It is likely the tensors ' - '(%s[1]) do not match the ' - 'function\'s arguments', name, e, name) - raise e - else: - ret[name] = self._host_fns[name](*dequeue_ops) - - return ret - - -class _OutfeedHostCallHook(session_run_hook.SessionRunHook): - """Hook to run host calls when use_tpu=False.""" - - def __init__(self, tensors): - self._tensors = tensors - - def begin(self): - # We duplicate this code from the TPUInfeedOutfeedSessionHook rather than - # create a separate hook to guarantee execution order, because summaries - # need to be initialized before the outfeed thread starts. - # TODO(jhseu): Make a wrapper hook instead? - self._init_ops = contrib_summary.summary_writer_initializer_op() - # Get all the writer resources from the initializer, so we know what to - # flush. 
- self._finalize_ops = [] - for op in self._init_ops: - self._finalize_ops.append(contrib_summary.flush(writer=op.inputs[0])) - - def after_create_session(self, session, coord): - session.run(self._init_ops) - - def before_run(self, run_context): - return basic_session_run_hooks.SessionRunArgs(self._tensors) - - def end(self, session): - session.run(self._finalize_ops) - - -class ExamplesPerSecondHook(basic_session_run_hooks.StepCounterHook): - """Calculate and report global_step/sec and examples/sec during runtime.""" - - def __init__(self, - batch_size, - every_n_steps=100, - every_n_secs=None, - output_dir=None, - summary_writer=None): - self._batch_size = batch_size - super(ExamplesPerSecondHook, self).__init__( - every_n_steps=every_n_steps, - every_n_secs=every_n_secs, - output_dir=output_dir, - summary_writer=summary_writer) - - def _log_and_record(self, elapsed_steps, elapsed_time, global_step): - global_step_per_sec = elapsed_steps / elapsed_time - examples_per_sec = self._batch_size * global_step_per_sec - if self._summary_writer is not None: - global_step_summary = Summary(value=[ - Summary.Value(tag='global_step/sec', simple_value=global_step_per_sec) - ]) - example_summary = Summary(value=[ - Summary.Value(tag='examples/sec', simple_value=examples_per_sec) - ]) - self._summary_writer.add_summary(global_step_summary, global_step) - self._summary_writer.add_summary(example_summary, global_step) - logging.info('global_step/sec: %g', global_step_per_sec) - logging.info('examples/sec: %g', examples_per_sec) - - -class InstallSignalHandlerHook(session_run_hook.SessionRunHook): - """Change SIGINT (CTRL^C) handler to force quit the process. - - The default behavior often results in hanging processes. - The original handler is restored after training/evaluation. - """ - - def __init__(self): - self._signal_fn = signal.getsignal(signal.SIGINT) - - def before_run(self, run_context): - signal.signal(signal.SIGINT, signal.SIG_DFL) - - def end(self, session): - signal.signal(signal.SIGINT, self._signal_fn) - - -class TPUEstimator(estimator_lib.Estimator): - """Estimator with TPU support. - - TPUEstimator also supports training on CPU and GPU. You don't need to define - a separate `tf.estimator.Estimator`. - - TPUEstimator handles many of the details of running on TPU devices, such as - replicating inputs and models for each core, and returning to host - periodically to run hooks. - - TPUEstimator transforms a global batch size in params to a per-shard batch - size when calling the `input_fn` and `model_fn`. Users should specify - global batch size in constructor, and then get the batch size for each shard - in `input_fn` and `model_fn` by `params['batch_size']`. - - - For training, `model_fn` gets per-core batch size; `input_fn` may get - per-core or per-host batch size depending on `per_host_input_for_training` - in `TPUConfig` (See docstring for TPUConfig for details). - - - For evaluation and prediction, `model_fn` gets per-core batch size and - `input_fn` get per-host batch size. - - Evaluation - ========== - - `model_fn` should return `TPUEstimatorSpec`, which expects the `eval_metrics` - for TPU evaluation. However, if eval_on_tpu is False, `model_fn` must return - `EstimatorSpec` and the evaluation will execute on CPU or GPU; in this case - the following discussion on TPU evaluation does not apply. - - `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`, where - `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. 
(See - `TPUEstimatorSpec` for details). `metric_fn` takes the `tensors` and returns - a dict from metric string name to the result of calling a metric function, - namely a `(metric_tensor, update_op)` tuple. - - One can set `use_tpu` to `False` for testing. All training, evaluation, and - predict will be executed on CPU. `input_fn` and `model_fn` will receive - `train_batch_size` or `eval_batch_size` unmodified as `params['batch_size']`. - - Current limitations: - -------------------- - - 1. TPU evaluation only works on a single host (one TPU worker) except - BROADCAST mode. - - 2. `input_fn` for evaluation should **NOT** raise an end-of-input exception - (`OutOfRangeError` or `StopIteration`). And all evaluation steps and all - batches should have the same size. - - Example (MNIST): - ---------------- - - ``` - # The metric Fn which runs on CPU. - def metric_fn(labels, logits): - predictions = tf.argmax(logits, 1) - return { - 'accuracy': tf.metrics.precision( - labels=labels, predictions=predictions), - } - - # Your model Fn which runs on TPU (eval_metrics is list in this example) - def model_fn(features, labels, mode, config, params): - ... - logits = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, [labels, logits])) - - # or specify the eval_metrics tensors as dict. - def model_fn(features, labels, mode, config, params): - ... - final_layer_output = ... - - if mode = tf.estimator.ModeKeys.EVAL: - return tpu_estimator.TPUEstimatorSpec( - mode=mode, - loss=loss, - eval_metrics=(metric_fn, { - 'labels': labels, - 'logits': final_layer_output, - })) - ``` - - Prediction - ========== - - Prediction on TPU is an experimental feature to support large batch inference. - It is not designed for latency-critical system. In addition, due to some - usability issues, for prediction with small dataset, CPU `.predict`, i.e., - creating a new `TPUEstimator` instance with `use_tpu=False`, might be more - convenient. - - Note: In contrast to TPU training/evaluation, the `input_fn` for prediction - *should* raise an end-of-input exception (`OutOfRangeError` or - `StopIteration`), which serves as the stopping signal to `TPUEstimator`. To be - precise, the ops created by `input_fn` produce one batch of the data. - The `predict()` API processes one batch at a time. When reaching the end of - the data source, an end-of-input exception should be raised by one of these - operations. The user usually does not need to do this manually. As long as the - dataset is not repeated forever, the `tf.data` API will raise an end-of-input - exception automatically after the last batch has been produced. - - Note: Estimator.predict returns a Python generator. Please consume all the - data from the generator so that TPUEstimator can shutdown the TPU system - properly for user. - - Current limitations: - -------------------- - 1. TPU prediction only works on a single host (one TPU worker). - - 2. `input_fn` must return a `Dataset` instance rather than `features`. In - fact, .train() and .evaluate() also support Dataset as return value. 
- - Example (MNIST): - ---------------- - ``` - height = 32 - width = 32 - total_examples = 100 - - def predict_input_fn(params): - batch_size = params['batch_size'] - - images = tf.random_uniform( - [total_examples, height, width, 3], minval=-1, maxval=1) - - dataset = tf.data.Dataset.from_tensor_slices(images) - dataset = dataset.map(lambda images: {'image': images}) - - dataset = dataset.batch(batch_size) - return dataset - - def model_fn(features, labels, params, mode): - # Generate predictions, called 'output', from features['image'] - - if mode == tf.estimator.ModeKeys.PREDICT: - return tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, - predictions={ - 'predictions': output, - 'is_padding': features['is_padding'] - }) - - tpu_est = TPUEstimator( - model_fn=model_fn, - ..., - predict_batch_size=16) - - # Fully consume the generator so that TPUEstimator can shutdown the TPU - # system. - for item in tpu_est.predict(input_fn=input_fn): - # Filter out item if the `is_padding` is 1. - # Process the 'predictions' - ``` - - Exporting - ========= - - `export_savedmodel` exports 2 metagraphs, one with `tag_constants.SERVING`, - and another with `tag_constants.SERVING` and `tag_constants.TPU`. - At serving time, these tags are used to select metagraph to load. - - Before running the graph on TPU, TPU system needs to be initialized. If - TensorFlow Serving model-server is used, this is done automatically. If - not, please call `session.run(tpu.initialize_system())`. - - `tpu.outside_compilation` can be used to wrap TPU incompatible ops in - `model_fn`. - - Example: - ---------------- - - ``` - def model_fn(features, labels, mode, config, params): - ... - logits = ... - export_outputs = { - 'logits': export_output_lib.PredictOutput( - {'logits': logits}) - } - - def host_call(logits): - class_ids = math_ops.argmax(logits) - classes = string_ops.as_string(class_ids) - export_outputs['classes'] = - export_output_lib.ClassificationOutput(classes=classes) - - tpu.outside_compilation(host_call, logits) - - ... - ``` - - """ - - def __init__(self, - model_fn=None, - train_cache_fn=None, - eval_cache_fn=None, - model_dir=None, - config=None, - params=None, - use_tpu=True, - train_batch_size=None, - eval_batch_size=None, - predict_batch_size=None, - batch_axis=None, - eval_on_tpu=True, - export_to_tpu=True, - warm_start_from=None): - """Constructs an `TPUEstimator` instance. - - Args: - model_fn: Model function as required by `Estimator` which returns - EstimatorSpec or TPUEstimatorSpec. `training_hooks`, 'evaluation_hooks', - and `prediction_hooks` must not capure any TPU Tensor inside the model_fn. - model_dir: Directory to save model parameters, graph and etc. This can - also be used to load checkpoints from the directory into a estimator to - continue training a previously saved model. If `None`, the model_dir in - `config` will be used if set. If both are set, they must be same. If - both are `None`, a temporary directory will be used. - config: An `tpu_config.RunConfig` configuration object. Cannot be `None`. - params: An optional `dict` of hyper parameters that will be passed into - `input_fn` and `model_fn`. Keys are names of parameters, values are - basic python types. There are reserved keys for `TPUEstimator`, - including 'batch_size'. - use_tpu: A bool indicating whether TPU support is enabled. Currently, - - TPU training and evaluation respect this bit, but eval_on_tpu can - override execution of eval. See below. - - Predict still happens on CPU. 
- train_batch_size: An int representing the global training batch size. - TPUEstimator transforms this global batch size to a per-shard batch - size, as params['batch_size'], when calling `input_fn` and `model_fn`. - Cannot be `None` if `use_tpu` is `True`. - Must be divisible by total number of replicas. - eval_batch_size: An int representing evaluation batch size. - Must be divisible by total number of replicas. - predict_batch_size: An int representing the prediction batch size. - Must be divisible by total number of replicas. - batch_axis: A python tuple of int values describing how each tensor - produced by the Estimator `input_fn` should be split across the TPU - compute shards. For example, if your input_fn produced (images, labels) - where the images tensor is in `HWCN` format, your shard dimensions would - be [3, 0], where 3 corresponds to the `N` dimension of your images - Tensor, and 0 corresponds to the dimension along which to split the - labels to match up with the corresponding images. If None is supplied, - and per_host_input_for_training is True, batches will be sharded based - on the major dimension. If tpu_config.per_host_input_for_training is - False or `PER_HOST_V2`, batch_axis is ignored. - eval_on_tpu: If False, evaluation runs on CPU or GPU. In this case, the - model_fn must return `EstimatorSpec` when called with `mode` as `EVAL`. - export_to_tpu: If True, `export_savedmodel()` exports a metagraph for - serving on TPU besides the one on CPU. - warm_start_from: Optional string filepath to a checkpoint or SavedModel to - warm-start from, or a `tf.estimator.WarmStartSettings` - object to fully configure warm-starting. If the string - filepath is provided instead of a `WarmStartSettings`, - then all variables are warm-started, and it is assumed - that vocabularies and Tensor names are unchanged. - - Raises: - ValueError: `params` has reserved keys already. - """ - if config is None or not isinstance(config, tpu_config.RunConfig): - raise ValueError( - '`config` must be provided with type `tpu_config.RunConfig`') - - if params is not None and any(k in params for k in _RESERVED_PARAMS_KEYS): - raise ValueError('{} are reserved keys but existed in params {}.'.format( - _RESERVED_PARAMS_KEYS, params)) - - if use_tpu: - # Perform some very basic validations. More validations will be found in - # _InternalTPUContext. - if train_batch_size is None: - raise ValueError('`train_batch_size` cannot be `None`') - util_lib.check_positive_integer(train_batch_size, 'train_batch_size') - - if (config.tpu_config.per_host_input_for_training is - tpu_config.InputPipelineConfig.PER_SHARD_V1 and - config.tpu_config.num_cores_per_replica): - raise ValueError( - 'Model parallelism only supports per host input for training. ' - 'Please adjust TPURunconfig.per_host_input_for_training.') - - if eval_batch_size is not None: - util_lib.check_positive_integer(eval_batch_size, 'eval_batch_size') - - if predict_batch_size is not None: - util_lib.check_positive_integer(predict_batch_size, - 'predict_batch_size') - - # Verifies the model_fn signature according to Estimator framework. - estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access - # We cannot store config and params in this constructor as parent - # constructor might change them, such as assigning a temp dir for - # config.model_dir. 
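Editor's aside: the constructor checks above require each batch size to be a positive integer divisible by the number of replicas, since the global size is split into the per-shard `params['batch_size']`. A minimal sketch of that bookkeeping; `num_replicas` and the sample sizes are hypothetical:

```
# Illustrative sketch only; not the TPUEstimator validation code itself.
def per_shard_batch_size(global_batch_size, num_replicas):
    if global_batch_size is None or global_batch_size <= 0:
        raise ValueError('batch size must be a positive integer')
    if global_batch_size % num_replicas != 0:
        raise ValueError(
            'global batch size %d must be divisible by the %d replicas'
            % (global_batch_size, num_replicas))
    return global_batch_size // num_replicas

# The value handed to input_fn/model_fn as params['batch_size']:
assert per_shard_batch_size(1024, num_replicas=8) == 128
```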
- model_function = self._augment_model_fn( - model_fn, - train_cache_fn, - eval_cache_fn, - batch_axis) - - # Overwrite log_step_count_steps to disable TensorLoggingHook and - # StepCounterHook from being created in Estimator. TPUEstimator already - # added equivalent hooks in _augment_model_fn above. - self._log_every_n_steps = config.log_step_count_steps - config = config.replace(log_step_count_steps=None) - - # Passing non-None params as wrapped model_fn has it. - params = params or {} - super(TPUEstimator, self).__init__( - model_fn=model_function, - model_dir=model_dir, - config=config, - params=params, - warm_start_from=warm_start_from) - self._iterations_per_training_loop = ( - self._config.tpu_config.iterations_per_loop) - - # All properties passed to _InternalTPUContext are immutable. - # pylint: disable=protected-access - self._ctx = tpu_context._get_tpu_context( - self._config, train_batch_size, - eval_batch_size, predict_batch_size, - use_tpu, - eval_on_tpu) - - self._export_to_tpu = export_to_tpu - - self._is_input_fn_invoked = None - self._rendezvous = {} - - def _add_meta_graph_for_mode(self, - builder, - input_receiver_fn_map, - checkpoint_path, - strip_default_attrs, - save_variables=True, - mode=model_fn_lib.ModeKeys.PREDICT, - export_tags=None, - check_variables=True): - if self._export_to_tpu and mode != model_fn_lib.ModeKeys.PREDICT: - raise NotImplementedError( - 'TPUEstimator only handles mode PREDICT for exporting ' - 'when `export_to_tpu` is `True`; ' - 'got {}.'.format(mode)) - - (super(TPUEstimator, self). - _add_meta_graph_for_mode(builder, - input_receiver_fn_map, - checkpoint_path, - strip_default_attrs, - save_variables, - mode=mode, - export_tags=export_tags, - check_variables=check_variables)) - - if self._export_to_tpu: - input_receiver_fn_map = {_REWRITE_FOR_INFERENCE_MODE: - input_receiver_fn_map[mode]} - export_tags = [tag_constants.SERVING, tag_constants.TPU] - mode = _REWRITE_FOR_INFERENCE_MODE - # See b/110052256 for why `check_variables` is `False`. - (super(TPUEstimator, self). - _add_meta_graph_for_mode(builder, - input_receiver_fn_map, - checkpoint_path, - strip_default_attrs, - save_variables=False, - mode=mode, - export_tags=export_tags, - check_variables=False)) - - def _call_model_fn(self, features, labels, mode, config): - if mode == _REWRITE_FOR_INFERENCE_MODE: - return self._call_model_fn_for_inference(features, labels, mode, config) - else: - return super(TPUEstimator, self)._call_model_fn( - features, labels, mode, config) - - def _call_model_fn_for_inference(self, features, labels, mode, config): - """Wraps `_call_model_fn` for `export_savedmodel`.""" - if mode != _REWRITE_FOR_INFERENCE_MODE: - raise ValueError('mode must be {}; ' - 'got {}.'.format(_REWRITE_FOR_INFERENCE_MODE, mode)) - - capture = _CapturedObject() - - def computation(): - """Compute tpu tensors used in export_outputs. - - Passed to rewrite_for_inference so that model_fn will be called under - the rewriting contexts. Only tpu tensors are returned, but export_outputs - and scaffold are captured. - - Returns: - A list of Tensors used in export_outputs and not marked for - outside_compilation. - """ - # We should only call model fn once and it should be inside `computation` - # so that building the graph will happen under `rewrite_for_inference`. - mode = model_fn_lib.ModeKeys.PREDICT - estimator_spec = self._call_model_fn(features, labels, mode, config) - - # We pick the TPU tensors out from `export_output` and later return them - # from `computation` for rewriting. 
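Editor's aside: the inference-rewrite path picks the TPU tensors out of the flattened export outputs and later swaps them, in order, for their rewritten CPU counterparts. A pure-Python sketch of that swap-in-order step; the `is_tpu` predicate and the string values are hypothetical stand-ins for real tensors:

```
# Illustrative sketch only; not the TensorFlow rewrite_for_inference code.
def replace_in_order(values, is_tpu, rewritten):
    """Swap every value matching is_tpu for the next item of `rewritten`."""
    rewritten = list(rewritten)
    out = []
    for v in values:
        out.append(rewritten.pop(0) if is_tpu(v) else v)
    assert not rewritten, 'every rewritten tensor must be consumed'
    return out

flat = ['tpu:logits', 'labels', 'tpu:probs']
cpu = ['cpu:logits', 'cpu:probs']
assert replace_in_order(flat, lambda v: v.startswith('tpu:'), cpu) == \
    ['cpu:logits', 'labels', 'cpu:probs']
```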
- tensors_dict = collections.OrderedDict( - (k, _export_output_to_tensors(v)) - for k, v in six.iteritems(estimator_spec.export_outputs) - ) - tensors = nest.flatten(tensors_dict) - tpu_tensors = [t for t in tensors if _is_tpu_tensor(t)] - - # We cannot return anything other than `tpu_tensors` here so we capture - # the rest for later use. - capture.capture((estimator_spec, tensors_dict, tensors)) - return tpu_tensors - - tpu_tensors_on_cpu = tpu.rewrite_for_inference(computation) - estimator_spec, tensors_dict, tensors = capture.get() - - # Reconstruct `tensors`, but with `tpu_tensors` replaced with - # `tpu_tensors_on_cpu`. - new_tensors = [] - for t in tensors: - if _is_tpu_tensor(t): - new_tensors.append(tpu_tensors_on_cpu.pop(0)) - elif t is None: - new_tensors.append(None) - else: - # Only fetching `tpu_tensors_on_cpu` does not trigger - # TPU computation and blocks, so we add the control dependency here. - control_inputs = (tpu_tensors_on_cpu - if isinstance(tpu_tensors_on_cpu, (list, tuple)) - else (tpu_tensors_on_cpu,)) - with ops.control_dependencies(control_inputs): - new_tensors.append(array_ops.identity(t)) - - # Reconstruct `tensors_dict`. - new_tensors_dict = nest.pack_sequence_as(tensors_dict, new_tensors) - # Reconstruct `export_outputs`. - export_outputs = estimator_spec.export_outputs - new_export_outputs = collections.OrderedDict( - (k, _clone_export_output_with_tensors(export_outputs[k], v)) - for k, v in six.iteritems(new_tensors_dict) - ) - - return estimator_spec._replace(export_outputs=new_export_outputs) - - def _create_global_step(self, graph): - """Creates a global step suitable for TPUs. - - Args: - graph: The graph in which to create the global step. - - Returns: - A global step `Tensor`. - - Raises: - ValueError: if the global step tensor is already defined. - """ - return _create_global_step(graph) - - def _convert_train_steps_to_hooks(self, steps, max_steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) - - # On TPU. - if steps is None and max_steps is None: - raise ValueError( - 'For TPU training, one of `steps` or `max_steps` must be set. ' - 'Cannot be both `None`.') - - # Estimator.train has explicit positiveness check. - if steps is not None: - util_lib.check_positive_integer(steps, 'Train steps') - if max_steps is not None: - util_lib.check_positive_integer(max_steps, 'Train max_steps') - - return [ - _TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps) - ] - - def _convert_eval_steps_to_hooks(self, steps): - with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: - if ctx.is_running_on_cpu(): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) - - if steps is None: - raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') - - util_lib.check_positive_integer(steps, 'Eval steps') - - return [ - evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access - num_evals=steps), - _SetEvalIterationsHook(steps) - ] - - def _call_input_fn(self, input_fn, mode): - """Calls the input function. - - Args: - input_fn: The input function. - mode: ModeKeys - - Returns: - Either features or (features, labels) where features and labels are: - features - `Tensor` or dictionary of string feature name to `Tensor`. - labels - `Tensor` or dictionary of `Tensor` with labels. - - Raises: - ValueError: if input_fn takes invalid arguments or does not have `params`. 
- """ - input_fn_args = function_utils.fn_args(input_fn) - config = self.config # a deep copy. - kwargs = {} - if 'params' in input_fn_args: - kwargs['params'] = self.params # a deep copy. - else: - raise ValueError('input_fn ({}) does not include params argument, ' - 'required by TPUEstimator to pass batch size as ' - 'params["batch_size"]'.format(input_fn)) - if 'config' in input_fn_args: - kwargs['config'] = config - - if 'mode' in input_fn_args: - kwargs['mode'] = mode - - # Records the fact input_fn has been invoked. - self._is_input_fn_invoked = True - - with self._ctx.with_mode(mode) as ctx: - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. - batch_size_for_input_fn = ctx.batch_size_for_input_fn - if batch_size_for_input_fn is not None: - _add_item_to_params(kwargs['params'], - _BATCH_SIZE_KEY, batch_size_for_input_fn) - - # For export_savedmodel, input_fn is never passed to Estimator. So, - # `is_export_mode` must be False. - if ctx.is_running_on_cpu(is_export_mode=False): - with ops.DEVICE('/DEVICE:CPU:0'): - return input_fn(**kwargs) - - # For TPU computation, input_fn should be invoked in a tf.while_loop for - # performance. While constructing the tf.while_loop, the structure of - # inputs returned by the `input_fn` needs to be recorded. The structure - # includes whether features or labels is dict or single Tensor, dict keys, - # tensor shapes, and dtypes. The recorded structure is used to create the - # infeed dequeue ops, which must be wrapped and passed as a Fn, called - # inside the TPU computation, as the TPU computation is wrapped inside a - # tf.while_loop also. So, we either pass input_fn to model_fn or pass - # dequeue_fn to model_fn. Here, `input_fn` is passed directly as - # `features` in `model_fn` signature. - def _input_fn(ctx): - _add_item_to_params(kwargs['params'], _CTX_KEY, ctx) - return input_fn(**kwargs) - - return _input_fn - - def _validate_features_in_predict_input(self, result): - """Skip the validation. - - For TPUEstimator, we do not need to check the result type. `_InputPipeline` - has stronger check. Parent class's check generates confusing warning msg. - - Args: - result: `features` returned by input_fn. 
- """ - pass - - def train(self, - input_fn, - hooks=None, - steps=None, - max_steps=None, - saving_listeners=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.TRAIN] = rendezvous - try: - return super(TPUEstimator, self).train( - input_fn=input_fn, hooks=hooks, steps=steps, max_steps=max_steps, - saving_listeners=saving_listeners - ) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('training_loop', sys.exc_info()) - finally: - rendezvous.record_done('training_loop') - rendezvous.raise_errors() - - def evaluate(self, input_fn, steps=None, hooks=None, checkpoint_path=None, - name=None): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.EVAL] = rendezvous - try: - return super(TPUEstimator, self).evaluate( - input_fn, steps=steps, hooks=hooks, checkpoint_path=checkpoint_path, - name=name - ) - except Exception: # pylint: disable=broad-except - rendezvous.record_error('evaluation_loop', sys.exc_info()) - finally: - rendezvous.record_done('evaluation_loop') - rendezvous.raise_errors() - - def predict(self, - input_fn, - predict_keys=None, - hooks=None, - checkpoint_path=None, - yield_single_examples=True): - rendezvous = error_handling.ErrorRendezvous(num_sources=3) - self._rendezvous[model_fn_lib.ModeKeys.PREDICT] = rendezvous - try: - for result in super(TPUEstimator, self).predict( - input_fn=input_fn, - predict_keys=predict_keys, - hooks=hooks, - checkpoint_path=checkpoint_path, - yield_single_examples=yield_single_examples): - yield result - except Exception: # pylint: disable=broad-except - rendezvous.record_error('prediction_loop', sys.exc_info()) - finally: - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - rendezvous.record_done('prediction_loop') - rendezvous.raise_errors() - - def _augment_model_fn(self, model_fn, train_cache_fn, eval_cache_fn, batch_axis): - """Returns a new model_fn, which wraps the TPU support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - with self._ctx.with_mode(mode) as ctx: - model_fn_wrapper = _ModelFnWrapper(model_fn, train_cache_fn, - eval_cache_fn, config, params, ctx) - - # `input_fn` is called in `train()`, `evaluate()`, and `predict()`, - # but not in `export_savedmodel()`. - if self._is_input_fn_invoked: - is_export_mode = False - else: - is_export_mode = True - - # Clear the bit. - self._is_input_fn_invoked = None - - # examples_hook is added to training_hooks for both CPU and TPU - # execution. - examples_hook = ExamplesPerSecondHook( - ctx.global_batch_size, - output_dir=self.model_dir, - every_n_steps=self._log_every_n_steps) - - if ctx.is_running_on_cpu(is_export_mode=is_export_mode): - logging.info('Running %s on CPU', mode) - estimator_spec = model_fn_wrapper.call_without_tpu( - features, labels, is_export_mode=is_export_mode) - estimator_spec = estimator_spec._replace( - training_hooks=estimator_spec.training_hooks + (examples_hook,)) - return estimator_spec - - assert labels is None, '`labels` passed to `model_fn` must be `None`.' - # TPUEstimator._call_input_fn passes `input_fn` as features to here. - assert callable(features), '`input_fn` is not callable.' 
- input_fn = features - - input_holders = _InputPipeline(input_fn, batch_axis, ctx) - enqueue_ops, dequeue_fn, input_hooks, run_infeed_loop_on_coordinator = ( - input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) - - graph = ops.get_default_graph() - for enqueue_op in enqueue_ops: - if isinstance(enqueue_op, list): - graph.get_collection_ref(_TPU_ENQUEUE_OPS).extend(enqueue_op) - else: - graph.add_to_collection(_TPU_ENQUEUE_OPS, enqueue_op) - - if mode == model_fn_lib.ModeKeys.TRAIN: - loss, host_call, scaffold, training_hooks = ( - _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn)) - - if model_fn_wrapper._params.get('track_mean', False): - iterations_per_loop_var = _create_or_get_iterations_per_loop() - loss = math_ops.div(loss, - math_ops.cast( - iterations_per_loop_var, - dtype=loss.dtype)) - - host_ops = host_call.create_tpu_hostcall() - if host_ops is None: - host_ops = [] - - shutdown_hooks = [] - shutdown_mode = os.environ.get('TF_TPU_GRACEFUL_SHUTDOWN_MODE', - 'shutdown_worker') - if shutdown_mode: - if shutdown_mode == 'shutdown_worker': - finalizer_hooks = [ - session_support.ShutdownLameWorkers(timeout_ms=60*1000), - ] - elif shutdown_mode == 'shutdown_computation': - finalizer_hooks = [ - session_support.RestartComputation(timeout_ms=60*1000), - ] - else: - raise ValueError('Unknown TF_TPU_GRACEFUL_SHUTDOWN_MODE "%s"' % - shutdown_mode) - - shutdown_hooks.append(session_support.GracefulShutdownHook( - checkpoint_prefix=self.model_dir + '/model.ckpt', - on_shutdown_hooks=finalizer_hooks - )) - - with ops.control_dependencies([loss]): - global_step = array_ops.identity(training.get_global_step()) - hooks = input_hooks + shutdown_hooks - logging_hook_frequency = ( # Divide and round up - (self._log_every_n_steps + - self._config.tpu_config.iterations_per_loop - 1) // - self._config.tpu_config.iterations_per_loop) - - iterations_per_loop = array_ops.identity( - _create_or_get_iterations_per_loop()) - - hooks.extend([ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - host_ops, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode], - ), - InstallSignalHandlerHook(), - training.LoggingTensorHook( - { - 'loss': array_ops.identity(loss), - 'ppl': tensorflow.exp(loss), - 'bpc': loss / tensorflow.constant(math.log(2)), - '#iter/loop': iterations_per_loop, - 'global step': global_step, - }, - every_n_iter=logging_hook_frequency) - ]) - examples_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - hooks.append(examples_hook) - - if training_hooks: - hooks.extend(training_hooks) - - chief_hooks = [] - if (self._config.save_checkpoints_secs or - self._config.save_checkpoints_steps): - checkpoint_hook = training.CheckpointSaverHook( - self.model_dir, - save_secs=self._config.save_checkpoints_secs, - save_steps=self._config.save_checkpoints_steps, - scaffold=scaffold) - checkpoint_hook._set_steps_per_run( # pylint: disable=protected-access - self._config.tpu_config.iterations_per_loop) - chief_hooks.append(checkpoint_hook) - - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops() - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - train_op = control_flow_ops.group(*update_ops) - graph.add_to_collection(_TPU_TRAIN_OP, train_op) - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_chief_hooks=chief_hooks, - training_hooks=hooks, - 
train_op=train_op, - scaffold=scaffold) - - if mode == model_fn_lib.ModeKeys.EVAL: - total_loss, host_calls, scaffold, eval_hooks = _eval_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) - iterations_per_loop_var = _create_or_get_iterations_per_loop() - mean_loss = math_ops.div(total_loss, - math_ops.cast( - iterations_per_loop_var, - dtype=total_loss.dtype)) - - # Creates a dummy metric update_op for all metrics. Estimator expects - # all metrics in eval_metric_ops have update_op and calls them one by - # one. The real metric update_ops are invoked in a separated thread. - # So, here give Estimator the dummy op for all metrics. - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), - # reads all variables back from TPU and updates the eval step - # counter properly - internal_ops_to_run = _sync_variables_ops() - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) - with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() - - host_call_ret = host_calls.create_tpu_hostcall() - eval_metric_ops = {} - eval_update_ops = [] - - for k, v in host_call_ret.get('eval_metrics', {}).items(): - eval_metric_ops[k] = (v[0], dummy_update_op) - eval_update_ops.append(v[1]) - - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - hooks = [ - TPUInfeedOutfeedSessionHook( - ctx, - enqueue_ops, - eval_update_ops + host_ops, - run_infeed_loop_on_coordinator=( - run_infeed_loop_on_coordinator), - rendezvous=self._rendezvous[mode]), - ] + input_hooks - - if eval_hooks: - hooks.extend(eval_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops, - scaffold=scaffold) - - # Predict - assert mode == model_fn_lib.ModeKeys.PREDICT - - (dummy_predict_op, host_calls, - scaffold, prediction_hooks) = _predict_on_tpu_system( - ctx, model_fn_wrapper, dequeue_fn) - with ops.control_dependencies([dummy_predict_op]): - internal_ops_to_run = _sync_variables_ops() - with ops.control_dependencies(internal_ops_to_run): - dummy_predict_op = control_flow_ops.no_op() - - # In train and evaluation, the main TPU program is passed to monitored - # training session to run. Infeed enqueue and outfeed dequeue are - # executed in side threads. This is not the configuration for - # prediction mode. - # - # For prediction, the Estimator executes the EstimatorSpec.predictions - # directly and yield the element (via generator) to call site. So, the - # outfeed based prediction must be passed to MonitoredSession directly. - # Other parts of the TPU execution are organized as follows. - # - # 1. All outfeed based Tensors must be grouped with predictions Tensors - # to form a single invocation. This avoid the issue we might trigger - # multiple outfeeds incorrectly. To achieve this, `host_call` is - # placed in control_dependencies of `stopping_signals`, and - # `stopping_signals` is passed into _StoppingPredictHook, which sets - # the `stopping_signals` as SessionRunArgs. MonitoredSession merges - # all SessionRunArgs with the fetch in session.run together. - # - # 2. The TPU program (dummy_predict_op) and enqueue_ops (infeed Enqueue) - # are grouped together. They will be launched once and only once in - # side threads and they quit naturally according to the SAME stopping - # condition. 
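The tensors wired into the LoggingTensorHook above are simple conversions of the per-loop mean loss, and logging_hook_frequency is a ceiling division so that at least one log line is emitted per host training loop. A tiny stand-alone sketch with hypothetical flag values:

import math

loss_nats = 0.95                       # hypothetical mean cross-entropy, in nats
ppl = math.exp(loss_nats)              # 'ppl' tensor above: perplexity = e**loss
bpc = loss_nats / math.log(2)          # 'bpc' tensor above: bits per character
print(round(ppl, 3), round(bpc, 3))    # 2.586 1.371

log_every_n_steps, iterations_per_loop = 100, 500   # hypothetical config values
logging_hook_frequency = ((log_every_n_steps + iterations_per_loop - 1)
                          // iterations_per_loop)
print(logging_hook_frequency)          # 1 -> log once per loop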
- enqueue_ops.append(dummy_predict_op) - - host_call_ret = host_calls.create_tpu_hostcall() - if 'host_call' not in host_call_ret: - host_ops = [] - else: - host_ops = host_call_ret['host_call'] - - predictions = host_call_ret['predictions'] - _verify_cross_hosts_transfer_size( - predictions, message=( - 'The estimated size for TPUEstimatorSpec.predictions is too ' - 'large.')) - signals = host_call_ret['signals'] - - with ops.control_dependencies(host_ops): - host_ops = [] # Empty, we do do not need it anymore. - scalar_stopping_signal = _StopSignals.as_scalar_stopping_signal( - signals) - predictions = _PaddingSignals.slice_tensor_or_dict( - predictions, signals) - - hooks = [ - _StoppingPredictHook(scalar_stopping_signal), - TPUInfeedOutfeedSessionHookForPrediction( - ctx, enqueue_ops, host_ops, rendezvous=self._rendezvous[mode]), - ] + input_hooks - - if prediction_hooks: - hooks.extend(prediction_hooks) - - return model_fn_lib.EstimatorSpec( - mode, - prediction_hooks=hooks, - predictions=predictions, - scaffold=scaffold) - - return _model_fn - - -def _is_tpu_tensor(tensor): - if not isinstance(tensor, ops.Tensor): - return False - try: - tensor.op.get_attr(tpu._OUTSIDE_COMPILATION_ATTR) # pylint: disable=protected-access - except ValueError: - return True - else: - return False - - -def _export_output_to_tensors(export_output): - """Get a list of `Tensors` used in `export_output`. - - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - Returns: - a list of tensors used in export_output. - - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - return [export_output.scores, export_output.classes] - elif isinstance(export_output, export_output_lib.RegressionOutput): - return [export_output.value] - elif isinstance(export_output, export_output_lib.PredictOutput): - return export_output.outputs.values() - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _clone_export_output_with_tensors(export_output, tensors): - """Clones `export_output` but with new `tensors`. - - Args: - export_output: an `ExportOutput` object such as `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. - tensors: a list of `Tensors` used to construct a new `export_output`. - - Returns: - A dict similar to `export_output` but with `tensors`. - - Raises: - ValueError: if `export_output` is not one of `ClassificationOutput`, - `RegressionOutput`, or `PredictOutput`. 
- """ - if isinstance(export_output, export_output_lib.ClassificationOutput): - if len(tensors) != 2: - raise ValueError('tensors must be of length 2; ' - 'got {}.'.format(len(tensors))) - return export_output_lib.ClassificationOutput(*tensors) - elif isinstance(export_output, export_output_lib.RegressionOutput): - if len(tensors) != 1: - raise ValueError('tensors must be of length 1; ' - 'got {}'.format(len(tensors))) - return export_output_lib.RegressionOutput(*tensors) - elif isinstance(export_output, export_output_lib.PredictOutput): - return export_output_lib.PredictOutput( - dict(zip(export_output.outputs.keys(), tensors))) - else: - raise ValueError( - '`export_output` must be have type `ClassificationOutput`, ' - '`RegressionOutput`, or `PredictOutput`; got {}.'.format(export_output)) - - -def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - (single_tpu_eval_step, host_calls, captured_scaffold_fn, captured_eval_hooks - ) = model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn) - - def multi_tpu_eval_steps_on_single_shard(): - loop_vars = [_ZERO_LOSS] - if model_fn_wrapper._eval_cache_fn is not None: - batch_size = ctx.global_batch_size - num_shards = ctx._config._tpu_config.num_shards - loop_vars += model_fn_wrapper._eval_cache_fn(batch_size // num_shards) - - return training_loop.repeat( - iterations_per_loop_var, - single_tpu_eval_step, - loop_vars) - - ret = tpu.shard( - multi_tpu_eval_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - loss = ret[0] - - scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_calls, scaffold, captured_eval_hooks.get() - - -def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - iterations_per_loop_var = _create_or_get_iterations_per_loop() - - (single_tpu_train_step, host_call, captured_scaffold_fn, - captured_training_hooks) = ( - model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)) - - def multi_tpu_train_steps_on_single_shard(): - if model_fn_wrapper._params.get('track_mean', False): - loop_vars = [_ZERO_LOSS] - else: - loop_vars = [_INITIAL_LOSS] - if model_fn_wrapper._train_cache_fn is not None: - batch_size = ctx.global_batch_size - num_shards = ctx._config._tpu_config.num_shards - loop_vars += model_fn_wrapper._train_cache_fn(batch_size // num_shards) - - return training_loop.repeat( - iterations_per_loop_var, - single_tpu_train_step, - loop_vars) - - ret = tpu.shard( - multi_tpu_train_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - loss = ret[0] - - scaffold = _get_scaffold(captured_scaffold_fn) - return loss, host_call, scaffold, captured_training_hooks.get() - - -def _predict_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): - """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - (single_tpu_predict_step, host_calls, captured_scaffold_fn, - captured_predict_hooks - ) = model_fn_wrapper.convert_to_single_tpu_predict_step(dequeue_fn) - - def multi_tpu_predict_steps_on_single_shard(): - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - inputs = [_StopSignals.NON_STOPPING_SIGNAL] - outputs = training_loop.while_loop( - 
cond, single_tpu_predict_step, inputs=inputs, name=b'loop') - return outputs - - (dummy_predict_op,) = tpu.shard( - multi_tpu_predict_steps_on_single_shard, - inputs=[], - num_shards=ctx.num_replicas, - outputs_from_all_shards=False, - device_assignment=ctx.device_assignment) - - scaffold = _get_scaffold(captured_scaffold_fn) - return dummy_predict_op, host_calls, scaffold, captured_predict_hooks.get() - - -def _wrap_computation_in_while_loop(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def computation(i): - with ops.control_dependencies(op_fn()): - return i + 1 - - iterations_per_loop_var = _create_or_get_iterations_per_loop() - # By setting parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.DEVICE(device): - iterations = array_ops.identity(iterations_per_loop_var) - return control_flow_ops.while_loop( - lambda i: i < iterations, - computation, [constant_op.constant(0)], - parallel_iterations=1) - - -def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn): - """Wraps the ops generated by `op_fn` in tf.while_loop.""" - - def cond(scalar_stopping_signal): - return math_ops.logical_not( - _StopSignals.should_stop(scalar_stopping_signal)) - - def computation(unused_scalar_stopping_signal): - return_value = op_fn() - execute_ops = return_value['ops'] - signals = return_value['signals'] - with ops.control_dependencies(execute_ops): - return _StopSignals.as_scalar_stopping_signal(signals) - - # By setting parallel_iterations=1, the parallel execution in while_loop is - # basically turned off. - with ops.DEVICE(device): - return control_flow_ops.while_loop( - cond, - computation, [_StopSignals.NON_STOPPING_SIGNAL], - parallel_iterations=1) - - -def _validate_tpu_training_graph(): - """Validate graph before running distributed training. - - Raises: - ValueError: If the graph seems invalid for running on DEVICE - """ - operations = ops.get_default_graph().get_operations() - - # Check if there is atleast one CrossReplicaSum operation in the graph - # This should be introduced by using the CrossShardOptimizer wrapper - cross_replica_sum_ops = [ - o for o in operations if o.type == _CROSS_REPLICA_SUM_OP - ] - if not cross_replica_sum_ops: - raise ValueError( - 'CrossShardOptimizer must be used for model training on TPUs.') - - -class _CapturedObject(object): - """A placeholder to capture an object. - - This is useful when we need to capture a Python object in the Tensorflow - control flow body function and use it outside the control flow. - """ - - def __init__(self): - self._object = None - self._captured = False - - def capture(self, o): - if self._captured: - raise RuntimeError( - 'InternalError: Object can capture only once. Please file bug.') - - self._captured = True - self._object = o - - def get(self): - if not self._captured: - raise RuntimeError( - 'InternalError: Object is not captured properly before `get`. 
' - 'Please file bug.') - return self._object - - -def _get_scaffold(captured_scaffold_fn): - """Retrieves the Scaffold from `captured_scaffold_fn`.""" - with _CapturingContext(message='Inside scaffold_fn'): - scaffold_fn = captured_scaffold_fn.get() - if scaffold_fn: - scaffold = scaffold_fn() - if scaffold is None: - raise ValueError( - 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') - else: - scaffold = None - - if scaffold: - wrapped_finalize = scaffold.finalize - - def _finalize(): - with _CapturingContext('Inside Scaffold.finalize'): - wrapped_finalize() - - scaffold.finalize = _finalize - return scaffold - - -class _CapturingContext(control_flow_ops.ControlFlowContext): - """Tracks references to Tensors defined in TPU replication.""" - - def __init__(self, message): - control_flow_ops.ControlFlowContext.__init__(self) - self._message = message - - def AddOp(self, op): # pylint: disable=invalid-name - for c in op.inputs: - if tpu._TPU_REPLICATE_ATTR in c.op.node_def.attr: # pylint: disable=protected-access - raise ValueError('{}: Op {} depends on TPU computation {}, ' - 'which is not allowed.'.format(self._message, op, c)) - - def to_control_flow_context_def(self, context_def, export_scope=None): - # pylint: disable=useless-super-delegation - # NOTE(slebedev): the method is required by `ControlFlowContext`. - super(_CapturingContext, self).to_control_flow_context_def( - context_def, export_scope) - - def __enter__(self): - # pylint: disable=protected-access - self._g = ops.get_default_graph() - self._old = self._g._get_control_flow_context() - self._g._set_control_flow_context(self) - # pylint: enable=protected-access - - def __exit__(self, _, __, ___): # pylint: disable=invalid-name - self._g._set_control_flow_context(self._old) # pylint: disable=protected-access - - -class _Inputs(object): - """A data structure representing the input_fn returned values. - - This also supports the returned value from input_fn as `Dataset`. - """ - - def __init__(self, features=None, labels=None, dataset=None, signals=None): - if dataset is not None and (features is not None or labels is not None or - signals is not None): - raise RuntimeError('Internal Error: Either (features and labels) or ' - 'dataset should be provided, not both. Please file ' - 'bug') - - self._features = features - self._labels = labels - self._signals = signals - - self._dataset = dataset - self._iterator = None - - @staticmethod - def from_input_fn(return_values): - """Returns an `_Inputs` instance according to `input_fn` return value.""" - if isinstance(return_values, dataset_ops.Dataset): - dataset = return_values - return _Inputs(dataset=dataset) - - features, labels = _Inputs._parse_inputs(return_values) - return _Inputs(features, labels) - - @staticmethod - def _parse_inputs(return_values): - if isinstance(return_values, tuple): - features, labels = return_values - else: - features, labels = return_values, None - return features, labels - - @property - def is_dataset(self): - """Returns True if the return value from input_fn is Dataset.""" - return self._dataset is not None - - def dataset_initializer_hook(self): - """Returns a `SessionRunHook` to initialize this dataset. - - This must be called before `features_and_labels`. 
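_Inputs normalizes the shapes an input_fn return value may take: a tf.data Dataset, a (features, labels) tuple, or features alone. The tuple-versus-features-only convention used by _parse_inputs is just the following (a trivial stand-alone restatement):

def parse_inputs(return_values):
    # Mirrors _Inputs._parse_inputs: a tuple is (features, labels),
    # anything else is features with no labels.
    if isinstance(return_values, tuple):
        features, labels = return_values
    else:
        features, labels = return_values, None
    return features, labels

print(parse_inputs(({'inputs': [1, 2]}, [0, 1])))   # features and labels
print(parse_inputs({'inputs': [1, 2]}))             # labels comes back as None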
- """ - iterator = self._dataset.make_initializable_iterator() - # pylint: disable=protected-access - hook = estimator_util._DatasetInitializerHook(iterator) - # pylint: enable=protected-access - self._iterator = iterator - return hook - - def features_and_labels(self): - """Gets `features` and `labels`.""" - if self.is_dataset: - if self._iterator is None: - raise RuntimeError('Internal error: Must call dataset_initializer_hook ' - 'before calling features_and_labels(). Please file ' - 'a bug!') - return _Inputs._parse_inputs(self._iterator.get_next()) - - return (self._features, self._labels) - - def signals(self): - return self._signals - - @property - def dataset(self): - return self._dataset - - -class _InputsWithStoppingSignals(_Inputs): - """Inputs with `_StopSignals` inserted into the dataset.""" - - def __init__(self, - dataset, - batch_size, - add_padding=False, - num_invocations_per_step=1): - - assert dataset is not None - user_provided_dataset = dataset.map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=False, batch_size=batch_size, add_padding=add_padding)) - if num_invocations_per_step == 1: - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - else: - # We append (2 * num_invocations_per_step - 1) batches for exhausting the - # user_provided_dataset and stop properly. - # For example, if num_invocations_per_step is 2, we append 3 additional - # padding batches: b1, b2, b3. - # If user_provided_dataset contains two batches: a1, a2 - # Step 1: [a1, a2] - # Step 2: [b1, b2] -> STOP - # If user_provided_dataset contains three batches: a1, a2, a3. - # The training loops: - # Step 1: [a1, a2] - # Step 2: [a3, b1] - # Step 3: [b2, b3] -> STOP. - final_batch_dataset = dataset.take(1).map( - _InputsWithStoppingSignals.insert_stopping_signal( - stop=True, batch_size=batch_size, add_padding=add_padding)) - final_batch_dataset = final_batch_dataset.repeat( - 2 * num_invocations_per_step - 1) - - def _set_mask(data_dict): - signals = data_dict['signals'] - signals['padding_mask'] = array_ops.ones_like(signals['padding_mask']) - data_dict['signals'] = signals - return data_dict - - # Mask out the extra batch. - final_batch_dataset = final_batch_dataset.map(_set_mask) - - dataset = user_provided_dataset.concatenate(final_batch_dataset).prefetch(2) - - super(_InputsWithStoppingSignals, self).__init__(dataset=dataset) - self._current_inputs = None - - def features_and_labels(self): - if self._current_inputs is not None: - raise RuntimeError( - 'Internal Error: The previous inputs have not been properly ' - 'consumed. First call features_and_labels, then call signals.') - - inputs_with_signals = self._iterator.get_next() - features = inputs_with_signals['features'] - labels = inputs_with_signals.get('labels') - - self._current_inputs = inputs_with_signals - return features, labels - - def signals(self): - """Returns the `Signals` from `_Inputs`.""" - if self._current_inputs is None: - raise RuntimeError( - 'Internal Error: The current inputs have not been properly ' - 'generated. First call features_and_labels, then call signals.') - signals = self._current_inputs['signals'] - self._current_inputs = None - return signals - - @staticmethod - def insert_stopping_signal(stop, batch_size, add_padding=False): - """Inserts stopping_signal into dataset via _map_fn. 
- - Here we change the data structure in the dataset, such that the return value - is a dictionary now and `features`, `labels`, and `signals` are three - distinguished keys in that dict. This provides a better structure, which - eases the process to decompose the inputs (see `features_and_labels`). - - Args: - stop: bool, state of current stopping signals. - batch_size: int, batch size. - add_padding: bool, whether to pad the tensor to full batch size. - - Returns: - A map_fn passed to dataset.map API. - """ - - def _map_fn(*args): - """The map fn to insert signals.""" - if len(args) == 1: - # Unpack the single Tensor/dict argument as features. This is required - # for the input_fn returns no labels. - args = args[0] - features, labels = _Inputs._parse_inputs(args) - new_input_dict = {} - - if add_padding: - padding_mask, features, labels = ( - _PaddingSignals.pad_features_and_labels( - features, labels, batch_size)) - - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - - else: - new_input_dict['features'] = features - if labels is not None: - new_input_dict['labels'] = labels - padding_mask = None - - new_input_dict['signals'] = _StopSignals( - stop=stop, batch_size=batch_size, padding_mask=padding_mask).as_dict() - - return new_input_dict - - return _map_fn - - -class _StopSignals(object): - """Signals class holding all logic to handle TPU stopping condition.""" - - NON_STOPPING_SIGNAL = False - STOPPING_SIGNAL = True - - def __init__(self, stop, batch_size, padding_mask=None): - self._stop = stop - self._batch_size = batch_size - self._padding_mask = padding_mask - - def as_dict(self): - """Returns the signals as Python dict.""" - shape = [self._batch_size, 1] - dtype = dtypes.bool - - if self._stop: - stopping = array_ops.ones(shape=shape, dtype=dtype) - else: - stopping = array_ops.zeros(shape=shape, dtype=dtype) - - signals = {'stopping': stopping} - if self._padding_mask is not None: - signals['padding_mask'] = self._padding_mask - return signals - - @staticmethod - def as_scalar_stopping_signal(signals): - return array_ops.identity(signals['stopping'][0][0]) - - @staticmethod - def should_stop(scalar_stopping_signal): - """Detects whether scalar_stopping_signal indicates stopping.""" - if isinstance(scalar_stopping_signal, ops.Tensor): - # STOPPING_SIGNAL is a constant True. Here, the logical_and is just the TF - # way to express the bool check whether scalar_stopping_signal is True. - return math_ops.logical_and( - scalar_stopping_signal, _StopSignals.STOPPING_SIGNAL) - else: - # For non Tensor case, it is used in SessionRunHook. So, we cannot modify - # the graph anymore. Here, we use pure Python. 
- return bool(scalar_stopping_signal) - - -class _PaddingSignals(object): - """Signals class holding all logic to handle padding.""" - - @staticmethod - def pad_features_and_labels(features, labels, batch_size): - """Pads out the batch dimension of features and labels.""" - real_batch_size = array_ops.shape( - _PaddingSignals._find_any_tensor(features))[0] - - batch_size_tensor = constant_op.constant(batch_size, dtypes.int32) - - check_greater = check_ops.assert_greater_equal( - batch_size_tensor, real_batch_size, - data=(batch_size_tensor, real_batch_size), - message='The real batch size should not be greater than batch_size.') - - with ops.control_dependencies([check_greater]): - missing_count = batch_size_tensor - real_batch_size - - def pad_single_tensor(tensor): - """Pads out the batch dimension of a tensor to the complete batch_size.""" - rank = len(tensor.shape) - assert rank > 0 - padding = array_ops.stack([[0, missing_count]] + [[0, 0]] * (rank - 1)) - padded_shape = (batch_size,) + tuple(tensor.shape[1:]) - padded_tensor = array_ops.pad(tensor, padding) - padded_tensor.set_shape(padded_shape) - return padded_tensor - - def nest_pad(tensor_or_dict): - return nest.map_structure(pad_single_tensor, tensor_or_dict) - - features = nest_pad(features) - if labels is not None: - labels = nest_pad(labels) - - padding_mask = _PaddingSignals._padding_mask( - real_batch_size, missing_count, batch_size) - - return padding_mask, features, labels - - @staticmethod - def slice_tensor_or_dict(tensor_or_dict, signals): - """Slice the real Tensors according to padding mask in signals.""" - - padding_mask = signals['padding_mask'] - batch_size = array_ops.shape(padding_mask)[0] - - def verify_batch_size(tensor): - check_batch_size = math_ops.equal(batch_size, tensor.shape[0]) - with ops.control_dependencies([check_batch_size]): - return array_ops.identity(tensor) - - def slice_single_tensor(tensor): - rank = len(tensor.shape) - assert rank > 0 - real_batch_size = batch_size - math_ops.reduce_sum(padding_mask) - return verify_batch_size(tensor)[0:real_batch_size] - - # As we split the Tensors to all TPU cores and concat them back, it is - # important to ensure the real data is placed before padded ones, i.e., - # order is preserved. By that, the sliced padding mask should have all 0's. - # If this assertion failed, # the slice logic here would not hold. - sliced_padding_mask = slice_single_tensor(padding_mask) - assert_padding_mask = math_ops.equal( - math_ops.reduce_sum(sliced_padding_mask), 0) - - with ops.control_dependencies([assert_padding_mask]): - should_stop = _StopSignals.should_stop( - _StopSignals.as_scalar_stopping_signal(signals)) - - is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0) - - def slice_fn(tensor): - # If the current batch is full batch or part of stopping signals, we do - # not need to slice to save performance. 
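Because every TPU step consumes batches of a fixed size, the helpers above zero-pad the final partial batch and keep an integer mask so the extra rows can be removed again after the computation. A small numpy analogue of that pad/slice round trip (assuming a hypothetical fixed batch size of 8 and a final real batch of 5 examples):

import numpy as np

batch_size, real = 8, 5
features = np.arange(real * 3, dtype=np.float32).reshape(real, 3)

# pad_features_and_labels: pad the batch dimension up to batch_size and record
# a mask that is 0 for real rows and 1 for padded rows.
missing = batch_size - real
padded = np.pad(features, [(0, missing), (0, 0)])
padding_mask = np.concatenate([np.zeros(real, np.int32), np.ones(missing, np.int32)])

# slice_tensor_or_dict: drop the padded rows again using the mask.
real_batch_size = batch_size - int(padding_mask.sum())
sliced = padded[:real_batch_size]
assert sliced.shape == features.shape and (sliced == features).all()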
- return control_flow_ops.cond( - math_ops.logical_or(should_stop, is_full_batch), - (lambda: verify_batch_size(tensor)), - (lambda: slice_single_tensor(tensor))) - - return nest.map_structure(slice_fn, tensor_or_dict) - - @staticmethod - def _find_any_tensor(batch_features): - tensors = [x for x in nest.flatten(batch_features) - if isinstance(x, ops.Tensor)] - if not tensors: - raise ValueError('Cannot find any Tensor in features dict.') - return tensors[0] - - @staticmethod - def _padding_mask(real_batch_size, missing_count, batch_size): - padding_mask = array_ops.concat( - [ - array_ops.zeros((real_batch_size,), dtype=dtypes.int32), - array_ops.ones((missing_count,), dtype=dtypes.int32) - ], - axis=0) - padding_mask.set_shape((batch_size,)) - return padding_mask - - -def _verify_cross_hosts_transfer_size(tensor_dict, message): - total_size = 0 - tensor_structure = {} - for key, tensor in tensor_dict.items(): - shape = tensor.shape - size = np.product(shape) * tensor.dtype.size - tensor_structure[key] = shape - total_size += size - if total_size >= _ONE_GIGABYTE: - raise ValueError( - '{} The transfer size is larger than the protobuf limit. Please ' - 'consider to use Tensors with smaller shapes or reduce batch ' - 'size. Given:\n' - '{}'.format(message, '\n'.join([ - ' -- Key: {}, Shape: {}'.format(k, v) - for k, v in tensor_structure.items()]))) - - -def _add_item_to_params(params, key, value): - """Adds a new item into `params`.""" - if isinstance(params, hparam.HParams): - # For HParams, we need to use special API. - if key in params: - params.set_hparam(key, value) - else: - params.add_hparam(key, value) - else: - # Now params is Python dict. - params[key] = value - - -def export_estimator_savedmodel(estimator, - export_dir_base, - serving_input_receiver_fn, - assets_extra=None, - as_text=False, - checkpoint_path=None, - strip_default_attrs=False): - """Export `Estimator` trained model for TPU inference. - - Args: - estimator: `Estimator` with which model has been trained. - export_dir_base: A string containing a directory in which to create - timestamped subdirectories containing exported SavedModels. - serving_input_receiver_fn: A function that takes no argument and - returns a `ServingInputReceiver` or `TensorServingInputReceiver`. - assets_extra: A dict specifying how to populate the assets.extra directory - within the exported SavedModel, or `None` if no extra assets are needed. - as_text: whether to write the SavedModel proto in text format. - checkpoint_path: The checkpoint path to export. If `None` (the default), - the most recent checkpoint found within the model directory is chosen. - strip_default_attrs: Boolean. If `True`, default-valued attributes will be - removed from the NodeDefs. - - Returns: - The string path to the exported directory. - """ - # `TPUEstimator` requires `tpu_config.RunConfig`, so we cannot use - # `estimator.config`. - config = tpu_config.RunConfig(model_dir=estimator.model_dir) - est = TPUEstimator( - estimator._model_fn, # pylint: disable=protected-access - config=config, - params=estimator.params, - use_tpu=True, - train_batch_size=2048, # Does not matter. - eval_batch_size=2048, # Does not matter. 
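The 1 GB guard in _verify_cross_hosts_transfer_size above is a plain byte count: the product of each tensor's static shape times its dtype size, summed over all prediction tensors. For example, a single float32 tensor with the hypothetical shape [2048, 1024, 128] hits the limit exactly:

import numpy as np

ONE_GIGABYTE = 1024 * 1024 * 1024
shape, dtype_size = (2048, 1024, 128), 4          # float32 is 4 bytes per element
size = int(np.prod(shape)) * dtype_size
print(size, size >= ONE_GIGABYTE)                 # 1073741824 True -> would raise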
- ) - return est.export_savedmodel(export_dir_base, serving_input_receiver_fn, - assets_extra, - as_text, - checkpoint_path, - strip_default_attrs) diff --git a/transformer-xl/tf/train.py b/transformer-xl/tf/train.py deleted file mode 100644 index 5ad7449..0000000 --- a/transformer-xl/tf/train.py +++ /dev/null @@ -1,462 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math -import time - -from absl import flags -import absl.logging as _logging # pylint: disable=unused-import - -from six.moves import xrange # pylint: disable=redefined-builtin - -import tensorflow as tf -from tensorflow.gfile import Exists as exists -import model -import data_utils -import tpu_estimator - -import numpy as np -from time import sleep - - -# TPU parameters -flags.DEFINE_string("master", default=None, - help="master") -flags.DEFINE_string("tpu", default=None, - help="The Cloud TPU to use for training. This should be either the name " - "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.") -flags.DEFINE_string("gcp_project", default=None, - help="Project name for the Cloud TPU-enabled project. If not specified, " - "we will attempt to automatically detect the GCE project from metadata.") -flags.DEFINE_string("tpu_zone",default=None, - help="GCE zone where the Cloud TPU is located in. If not specified, we " - "will attempt to automatically detect the GCE project from metadata.") -flags.DEFINE_bool("use_tpu", default=True, - help="Use TPUs rather than plain CPUs.") -flags.DEFINE_integer("num_hosts", default=1, - help="number of TPU hosts") -flags.DEFINE_integer("num_core_per_host", default=8, - help="number of cores per host") - -# Experiment (data/checkpoint/directory) parameters -flags.DEFINE_string("data_dir", default="", - help="Path to tf-records directory.") -flags.DEFINE_string("record_info_dir", default="", - help="Path to local directory containing filenames.txt.") -flags.DEFINE_string("corpus_info_path", default="", - help="Path to corpus-info.json file.") -flags.DEFINE_string("model_dir", default=None, - help="Estimator model_dir.") -flags.DEFINE_bool("do_eval", default=False, - help="Whether to run eval on the dev set.") -flags.DEFINE_bool("track_mean", default=True, - help="Trace mean loss during training.") -flags.DEFINE_string("eval_ckpt_path", None, - help="Checkpoint path for evaluation." - "If set, model_dir will be ignored." - "If unset, will use the latest ckpt in model_dir.") -flags.DEFINE_string("warm_start_path", None, - help="Checkpoint path for warm start." - "If set, will clear Adam states." 
- "Note that the new model_dir should be different" - " from warm_start_path.") - -# Optimization paramenters -flags.DEFINE_float("learning_rate", default=2.5e-4, - help="Maximum learning rate.") -flags.DEFINE_float("clip", default=0.25, - help="Gradient clipping value.") -# for cosine decay -flags.DEFINE_float("min_lr_ratio", default=0.01, - help="Minimum ratio learning rate.") -flags.DEFINE_integer("warmup_steps", default=0, - help="Number of steps for linear lr warmup.") - -# Training parameters -flags.DEFINE_integer("train_batch_size", default=60, - help="Size of train batch.") -flags.DEFINE_integer("eval_batch_size", default=60, - help="Size of valid batch.") -flags.DEFINE_integer("train_steps", default=100000, - help="Total number of training steps.") -flags.DEFINE_integer("iterations", default=500, - help="Number of iterations per repeat loop.") -flags.DEFINE_integer("save_steps", default=10000, - help="number of steps for model checkpointing.") - -# Evaluation parameters -flags.DEFINE_integer("max_eval_batch", default=-1, - help="Set -1 to turn off. Only used in test mode.") -flags.DEFINE_bool("do_eval_only", default=False, - help="Run evaluation only.") -flags.DEFINE_integer("start_eval_steps", default=10000, - help="Which checkpoint to start with in `do_eval_only` mode.") -flags.DEFINE_string("eval_split", "valid", - help="Which data split to evaluate.") - -# Model paramenters -flags.DEFINE_integer("tgt_len", default=70, - help="Number of steps to predict") -flags.DEFINE_integer("mem_len", default=70, - help="Number of steps to cache") -flags.DEFINE_bool("same_length", default=False, - help="Same length attention") -flags.DEFINE_integer("clamp_len", default=-1, - help="Clamp length") - -flags.DEFINE_integer("n_layer", default=6, - help="Number of layers.") -flags.DEFINE_integer("d_model", default=500, - help="Dimension of the model.") -flags.DEFINE_integer("d_embed", default=500, - help="Dimension of the embeddings.") -flags.DEFINE_integer("n_head", default=10, - help="Number of attention heads.") -flags.DEFINE_integer("d_head", default=50, - help="Dimension of each attention head.") -flags.DEFINE_integer("d_inner", default=1000, - help="Dimension of inner hidden size in positionwise feed-forward.") -flags.DEFINE_float("dropout", default=0.1, - help="Dropout rate.") -flags.DEFINE_float("dropatt", default=0.1, - help="Attention dropout rate.") -flags.DEFINE_bool("untie_r", default=False, - help="untie r_w_bias and r_r_bias") - -# Adaptive Softmax / Embedding -flags.DEFINE_bool("tie_weight", default=True, - help="Tie embedding and softmax weight.") -flags.DEFINE_integer("div_val", default=1, - help="Divide the embedding size by this val for each bin") -flags.DEFINE_bool("proj_share_all_but_first", default=False, - help="True to share all but first projs, False not to share.") -flags.DEFINE_bool("proj_same_dim", default=True, - help="Project the bin with the same dimension.") - -# Parameter initialization -flags.DEFINE_enum("init", default="normal", - enum_values=["normal", "uniform"], - help="Initialization method.") -flags.DEFINE_float("init_std", default=0.02, - help="Initialization std when init is normal.") -flags.DEFINE_float("proj_init_std", default=0.01, - help="Initialization std for embedding projection.") -flags.DEFINE_float("init_range", default=0.1, - help="Initialization std when init is uniform.") - - -FLAGS = flags.FLAGS - -def metric_fn(loss): - """Evaluation metric Fn which runs on CPU.""" - perplexity = tf.exp(tf.reduce_mean(loss)) - bpc = tf.reduce_mean(loss) / 
tf.constant(math.log(2)) - return { - "perplexity": tf.metrics.mean(perplexity), - "bpc": tf.metrics.mean(bpc), - } - - -def get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes): - def model_fn(features, labels, mode, params): - is_training = (mode == tf.estimator.ModeKeys.TRAIN) - - - batch_size = params["batch_size"] - - mems = params["cache"] - inp = tf.transpose(features["inputs"], [1, 0]) - tgt = tf.transpose(features["labels"], [1, 0]) - - bin_sizes = train_bin_sizes if is_training else eval_bin_sizes - if bin_sizes: - inp_perms = [tf.transpose(features["inp_mask"], [1, 0])] - tgt_perms = [tf.transpose(features["tgt_mask"], [1, 0])] - - head_tgt = tf.transpose(features["head_labels"], [1, 0]) - - for b in range(len(bin_sizes)): - inp_perm = tf.transpose(features["inp_perm_{}".format(b)], [1, 0, 2]) - tgt_perm = tf.transpose(features["tgt_perm_{}".format(b)], [1, 0, 2]) - - inp_perms.append(inp_perm) - tgt_perms.append(tgt_perm) - else: - inp_perms, tgt_perms, head_tgt = None, None, None - - if FLAGS.init == "uniform": - initializer = tf.initializers.random_uniform( - minval=-FLAGS.init_range, - maxval=FLAGS.init_range, - seed=None) - elif FLAGS.init == "normal": - initializer = tf.initializers.random_normal( - stddev=FLAGS.init_std, - seed=None) - proj_initializer = tf.initializers.random_normal( - stddev=FLAGS.proj_init_std, - seed=None) - - tie_projs = [False for _ in range(len(cutoffs) + 1)] - if FLAGS.proj_share_all_but_first: - for i in range(1, len(tie_projs)): - tie_projs[i] = True - - tf.logging.info("Vocab size : {}".format(n_token)) - tf.logging.info("Batch size : {}".format(batch_size)) - - loss, new_mems = model.transformer( - dec_inp=inp, - target=tgt, - mems=mems, - n_token=n_token, - n_layer=FLAGS.n_layer, - d_model=FLAGS.d_model, - d_embed=FLAGS.d_embed, - n_head=FLAGS.n_head, - d_head=FLAGS.d_head, - d_inner=FLAGS.d_inner, - dropout=FLAGS.dropout, - dropatt=FLAGS.dropatt, - initializer=initializer, - is_training=is_training, - mem_len=FLAGS.mem_len, - cutoffs=cutoffs, - div_val=FLAGS.div_val, - tie_projs=tie_projs, - input_perms=inp_perms, - target_perms=tgt_perms, - head_target=head_tgt, - same_length=FLAGS.same_length, - clamp_len=FLAGS.clamp_len, - use_tpu=FLAGS.use_tpu, - untie_r=FLAGS.untie_r, - proj_same_dim=FLAGS.proj_same_dim) - - total_loss = tf.reduce_mean(loss) - - if mode == tf.estimator.ModeKeys.EVAL: - if FLAGS.use_tpu: - with tf.colocate_with(total_loss): - total_loss = tf.contrib.tpu.cross_replica_sum(total_loss) \ - / FLAGS.num_hosts / FLAGS.num_core_per_host - metric_loss = tf.tile(tf.reshape(total_loss, [1, 1]), [batch_size, 1]) - eval_spec = tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, - loss=total_loss, - eval_metrics=(metric_fn, [metric_loss])) - - eval_spec.cache = new_mems - - return eval_spec - - # Configuring the optimization step. 
- global_step = tf.train.get_global_step() - - # increase the learning rate linearly - if FLAGS.warmup_steps > 0: - warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \ - * FLAGS.learning_rate - else: - warmup_lr = 0.0 - - # number of parameters - num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()]) - tf.logging.info("#params: {}".format(num_params)) - - # format_str = '{{:<{0}s}}\t{{}}'.format( - # max([len(v.name) for v in tf.trainable_variables()])) - # for v in tf.trainable_variables(): - # tf.logging.info(format_str.format(v.name, v.get_shape())) - - - # decay the learning rate using the cosine schedule - decay_lr = tf.train.cosine_decay( - FLAGS.learning_rate, - global_step=global_step-FLAGS.warmup_steps, - decay_steps=FLAGS.train_steps-FLAGS.warmup_steps, - alpha=FLAGS.min_lr_ratio) - - learning_rate = tf.where(global_step < FLAGS.warmup_steps, - warmup_lr, decay_lr) - - if FLAGS.use_tpu: - optimizer = tf.contrib.tpu.CrossShardOptimizer( - tf.train.AdamOptimizer(learning_rate=learning_rate)) - #GradientDescentOptimizer - else: - optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) - - grads_and_vars = optimizer.compute_gradients(total_loss) - gradients, variables = zip(*grads_and_vars) - clipped, _ = tf.clip_by_global_norm(gradients, FLAGS.clip) - train_op = optimizer.apply_gradients( - zip(clipped, variables), global_step=tf.train.get_global_step()) - - # Constucting TPUEstimatorSpec with cache. - train_spec = tf.contrib.tpu.TPUEstimatorSpec( - mode=mode, loss=total_loss, train_op=train_op) - - if FLAGS.mem_len < FLAGS.tgt_len: - new_mems = [new_mems[: FLAGS.mem_len] for mem_t in new_mems] - train_spec.cache = new_mems - - return train_spec - - return model_fn - - -def get_cache_fn(mem_len): - - def cache_fn(batch_size): - mems = [] - for l in xrange(FLAGS.n_layer): - if mem_len > 0: - mems.append( - tf.zeros([mem_len, batch_size, FLAGS.d_model], dtype=tf.float32)) - else: - mems.append(tf.zeros([mem_len], dtype=tf.float32)) - - return mems - - return cache_fn - - -def main(unused_argv): - del unused_argv # Unused - - tf.logging.set_verbosity(tf.logging.INFO) - - # Get corpus info - corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path) - n_token = corpus_info["vocab_size"] - cutoffs = corpus_info["cutoffs"][1:-1] - - if FLAGS.save_steps == 0: - FLAGS.save_steps = None - - if not FLAGS.do_eval_only: - # Get train input function - train_input_fn, train_record_info = data_utils.get_input_fn( - record_info_dir=FLAGS.record_info_dir, - split="train", - per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts, - tgt_len=FLAGS.tgt_len, - num_core_per_host=FLAGS.num_core_per_host, - num_hosts=FLAGS.num_hosts, - use_tpu=FLAGS.use_tpu) - train_bin_sizes = train_record_info["bin_sizes"] - num_train_batch = train_record_info["num_batch"] - - # Get train cache function - train_cache_fn = get_cache_fn(FLAGS.mem_len) - else: - train_bin_sizes = [] - num_train_batch = None - train_cache_fn = None - - if FLAGS.do_eval or FLAGS.do_eval_only: - assert FLAGS.num_hosts == 1 - # Get eval input function - eval_input_fn, eval_record_info = data_utils.get_input_fn( - record_info_dir=FLAGS.record_info_dir, - split=FLAGS.eval_split, - per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts, - tgt_len=FLAGS.tgt_len, - num_core_per_host=FLAGS.num_core_per_host, - num_hosts=FLAGS.num_hosts, - use_tpu=FLAGS.use_tpu) - eval_bin_sizes = eval_record_info["bin_sizes"] - num_eval_batch = eval_record_info["num_batch"] - - if FLAGS.max_eval_batch > 0: - 
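The learning-rate schedule assembled above is a linear warmup to FLAGS.learning_rate followed by tf.train.cosine_decay down to min_lr_ratio * learning_rate. A stand-alone sketch of the same curve, using this script's default learning_rate=2.5e-4 and min_lr_ratio=0.01 together with hypothetical warmup_steps=1000 and train_steps=100000:

import math

learning_rate, min_lr_ratio = 2.5e-4, 0.01
warmup_steps, train_steps = 1000, 100000          # hypothetical flag values

def lr_at(step):
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps * learning_rate            # linear warmup
    # cosine decay: fall to min_lr_ratio * learning_rate over the remaining steps
    progress = min(step - warmup_steps, train_steps - warmup_steps) / float(
        train_steps - warmup_steps)
    decayed = (1 - min_lr_ratio) * 0.5 * (1 + math.cos(math.pi * progress)) + min_lr_ratio
    return learning_rate * decayed

for s in (0, 500, 1000, 50000, 100000):
    print(s, lr_at(s))   # 0.0, 1.25e-4, 2.5e-4, ~1.28e-4, 2.5e-6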
num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch) - - # Get eval cache function - eval_cache_fn = get_cache_fn(FLAGS.mem_len) - model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes) - else: - eval_cache_fn = None - model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, []) - - ##### Create estimator - # TPU Configuration - tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver( - FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project) - - per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2 - run_config = tf.contrib.tpu.RunConfig( - cluster=tpu_cluster_resolver, - model_dir=FLAGS.model_dir, - session_config=tf.ConfigProto( - allow_soft_placement=True, log_device_placement=True), - tpu_config=tf.contrib.tpu.TPUConfig( - iterations_per_loop=FLAGS.iterations, - num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts, - per_host_input_for_training=per_host_input), - keep_checkpoint_max=100000, # effectively save all checkpoints - save_checkpoints_secs=None, - save_checkpoints_steps=FLAGS.save_steps - ) - - # warm start - warm_start_from = None - if FLAGS.warm_start_path is not None: - warm_start_from = tf.estimator.WarmStartSettings( - ckpt_to_initialize_from=FLAGS.warm_start_path) - - # TPU Estimator - estimator = tpu_estimator.TPUEstimator( - model_fn=model_fn, - train_cache_fn=train_cache_fn, - eval_cache_fn=eval_cache_fn, - use_tpu=FLAGS.use_tpu, - config=run_config, - params={"data_dir":FLAGS.data_dir, "track_mean":FLAGS.track_mean}, - train_batch_size=FLAGS.train_batch_size, - eval_batch_size=FLAGS.eval_batch_size, - warm_start_from=warm_start_from) - - if FLAGS.do_eval_only: - if FLAGS.eval_ckpt_path is not None: - ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch, - checkpoint_path=FLAGS.eval_ckpt_path) - tf.logging.info("=" * 200) - log_str = "Eval results | " - for key, val in ret.items(): - log_str += "{} {} | ".format(key, val) - tf.logging.info(log_str) - tf.logging.info("=" * 200) - else: - ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir) - eval_results = [] - for eval_checkpoint in ckpt_state.all_model_checkpoint_paths: - if not exists(eval_checkpoint + ".index"): continue - global_step = int(eval_checkpoint.split("-")[-1]) - if global_step < FLAGS.start_eval_steps or global_step > FLAGS.train_steps: - continue - ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch, - checkpoint_path=eval_checkpoint) - eval_results.append(ret) - - eval_results.sort(key = lambda x: x["perplexity"]) - - tf.logging.info("=" * 200) - log_str = "Best results | " - for key, val in eval_results[0].items(): - log_str += "{} {} | ".format(key, val) - tf.logging.info(log_str) - tf.logging.info("=" * 200) - else: - if not FLAGS.do_eval: - estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps) - else: - for step in range(0, FLAGS.train_steps, num_train_batch): - train_steps = min(FLAGS.train_steps - step, num_train_batch) - estimator.train(input_fn=train_input_fn, steps=train_steps) - estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch) - - -if __name__ == "__main__": - tf.app.run() diff --git a/transformer-xl/tf/train_gpu.py b/transformer-xl/tf/train_gpu.py deleted file mode 100644 index bf83b79..0000000 --- a/transformer-xl/tf/train_gpu.py +++ /dev/null @@ -1,475 +0,0 @@ -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import math -import time - -from absl import flags -import absl.logging as _logging 
# pylint: disable=unused-import - -import tensorflow as tf -import model -import data_utils - -from gpu_utils import assign_to_gpu, average_grads_and_vars - -import numpy as np - -# GPU config -flags.DEFINE_integer("num_hosts", default=1, - help="Number of TPU hosts") -flags.DEFINE_integer("num_core_per_host", default=8, - help="Number of cores per host") - -# Experiment (data/checkpoint/directory) config -flags.DEFINE_string("data_dir", default="", - help="Path to tf-records directory.") -flags.DEFINE_string("record_info_dir", default="", - help="Path to local directory containing filenames.txt.") -flags.DEFINE_string("corpus_info_path", default="", - help="Path to corpus-info.json file.") -flags.DEFINE_string("model_dir", default=None, - help="Estimator model_dir.") -flags.DEFINE_bool("do_train", default=True, - help="Whether to run training.") -flags.DEFINE_bool("do_eval", default=False, - help="Whether to run eval on the dev set.") -flags.DEFINE_string("eval_ckpt_path", None, - help="Checkpoint path for do_test evaluation." - "If set, model_dir will be ignored." - "If unset, will use the latest ckpt in model_dir.") -flags.DEFINE_string("warm_start_path", None, - help="Checkpoint path for warm start." - "If set, will clear Adam states." - "Note that the new model_dir should be different" - " from warm_start_path.") - -# Optimization config -flags.DEFINE_float("learning_rate", default=2.5e-4, - help="Maximum learning rate.") -flags.DEFINE_float("clip", default=0.25, - help="Gradient clipping value.") -# for cosine decay -flags.DEFINE_float("min_lr_ratio", default=0.004, - help="Minimum ratio learning rate.") -flags.DEFINE_integer("warmup_steps", default=0, - help="Number of steps for linear lr warmup.") - -# Training config -flags.DEFINE_integer("train_batch_size", default=60, - help="Size of train batch.") -flags.DEFINE_integer("eval_batch_size", default=60, - help="Size of valid batch.") -flags.DEFINE_integer("train_steps", default=100000, - help="Total number of training steps.") -flags.DEFINE_integer("iterations", default=500, - help="Number of iterations per repeat loop.") -flags.DEFINE_integer("save_steps", default=10000, - help="number of steps for model checkpointing.") - -# Evaluation config -flags.DEFINE_bool("do_test", default=False, - help="Run on the test set.") -flags.DEFINE_integer("max_eval_batch", default=-1, - help="Set -1 to turn off. 
Only used in test mode.") -flags.DEFINE_bool("do_eval_only", default=False, - help="Run evaluation only.") -flags.DEFINE_integer("start_eval_steps", default=10000, - help="Which checkpoint to start with in `do_eval_only` mode.") -flags.DEFINE_string("eval_split", "valid", - help="Which data split to evaluate.") - -# Model config -flags.DEFINE_integer("tgt_len", default=70, - help="Number of steps to predict") -flags.DEFINE_integer("mem_len", default=70, - help="Number of steps to cache") -flags.DEFINE_bool("same_length", default=False, - help="Same length attention") -flags.DEFINE_integer("clamp_len", default=-1, - help="Clamp length") - -flags.DEFINE_integer("n_layer", default=6, - help="Number of layers.") -flags.DEFINE_integer("d_model", default=500, - help="Dimension of the model.") -flags.DEFINE_integer("d_embed", default=500, - help="Dimension of the embeddings.") -flags.DEFINE_integer("n_head", default=10, - help="Number of attention heads.") -flags.DEFINE_integer("d_head", default=50, - help="Dimension of each attention head.") -flags.DEFINE_integer("d_inner", default=1000, - help="Dimension of inner hidden size in positionwise feed-forward.") -flags.DEFINE_float("dropout", default=0.1, - help="Dropout rate.") -flags.DEFINE_float("dropatt", default=0.1, - help="Attention dropout rate.") -flags.DEFINE_bool("untie_r", default=False, - help="untie r_w_bias and r_r_bias") - -# Adaptive Softmax / Embedding -flags.DEFINE_bool("tie_weight", default=True, - help="Tie embedding and softmax weight.") -flags.DEFINE_integer("div_val", default=1, - help="Divide the embedding size by this val for each bin") -flags.DEFINE_bool("proj_share_all_but_first", default=False, - help="True to share all but first projs, False not to share.") -flags.DEFINE_bool("proj_same_dim", default=True, - help="Project the bin with the same dimension.") - -# Parameter initialization -flags.DEFINE_enum("init", default="normal", - enum_values=["normal", "uniform"], - help="Initialization method.") -flags.DEFINE_float("init_std", default=0.02, - help="Initialization std when init is normal.") -flags.DEFINE_float("proj_init_std", default=0.01, - help="Initialization std for embedding projection.") -flags.DEFINE_float("init_range", default=0.1, - help="Initialization std when init is uniform.") - -FLAGS = flags.FLAGS - -def get_model_fn(n_token, cutoffs): - def model_fn(inp, tgt, mems, is_training): - inp = tf.transpose(inp, [1, 0]) - tgt = tf.transpose(tgt, [1, 0]) - - if FLAGS.init == "uniform": - initializer = tf.initializers.random_uniform( - minval=-FLAGS.init_range, - maxval=FLAGS.init_range, - seed=None) - elif FLAGS.init == "normal": - initializer = tf.initializers.random_normal( - stddev=FLAGS.init_std, - seed=None) - proj_initializer = tf.initializers.random_normal( - stddev=FLAGS.proj_init_std, - seed=None) - - tie_projs = [False for _ in range(len(cutoffs) + 1)] - if FLAGS.proj_share_all_but_first: - for i in range(1, len(tie_projs)): - tie_projs[i] = True - - loss, new_mems = model.transformer( - dec_inp=inp, - target=tgt, - mems=mems, - n_token=n_token, - n_layer=FLAGS.n_layer, - d_model=FLAGS.d_model, - d_embed=FLAGS.d_embed, - n_head=FLAGS.n_head, - d_head=FLAGS.d_head, - d_inner=FLAGS.d_inner, - dropout=FLAGS.dropout, - dropatt=FLAGS.dropatt, - initializer=initializer, - proj_initializer=proj_initializer, - is_training=is_training, - mem_len=FLAGS.mem_len, - cutoffs=cutoffs, - div_val=FLAGS.div_val, - tie_projs=tie_projs, - input_perms=None, - target_perms=None, - head_target=None, - 
same_length=FLAGS.same_length, - clamp_len=FLAGS.clamp_len, - use_tpu=False, - untie_r=FLAGS.untie_r, - proj_same_dim=FLAGS.proj_same_dim) - - # number of parameters - num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()]) - tf.logging.info('#params: {}'.format(num_params)) - - # format_str = '{{:<{0}s}}\t{{}}'.format( - # max([len(v.name) for v in tf.trainable_variables()])) - # for v in tf.trainable_variables(): - # tf.logging.info(format_str.format(v.name, v.get_shape())) - - if is_training: - all_vars = tf.trainable_variables() - grads = tf.gradients(loss, all_vars) - grads_and_vars = list(zip(grads, all_vars)) - - return loss, new_mems, grads_and_vars - else: - return loss, new_mems - - return model_fn - - -def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems): - model_fn = get_model_fn( - n_token=n_token, - cutoffs=cutoffs) - - model_ret = model_fn( - inp=inp, - tgt=tgt, - mems=mems, - is_training=is_training) - - return model_ret - - -def train(n_token, cutoffs, ps_device): - ##### Get input function and model function - train_input_fn, train_record_info = data_utils.get_input_fn( - record_info_dir=FLAGS.record_info_dir, - split="train", - per_host_bsz=FLAGS.train_batch_size, - tgt_len=FLAGS.tgt_len, - num_core_per_host=FLAGS.num_core_per_host, - num_hosts=1, - use_tpu=False) - - tf.logging.info("num of batches {}".format(train_record_info["num_batch"])) - - ##### Create computational graph - train_set = train_input_fn({ - "batch_size": FLAGS.train_batch_size, - "data_dir": FLAGS.data_dir}) - - input_feed, label_feed = train_set.make_one_shot_iterator().get_next() - - inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0) - labels = tf.split(label_feed, FLAGS.num_core_per_host, 0) - - per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host - - tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], [] - - for i in range(FLAGS.num_core_per_host): - reuse = True if i > 0 else None - with tf.DEVICE(assign_to_gpu(i, ps_device)), \ - tf.variable_scope(tf.get_variable_scope(), reuse=reuse): - - mems_i = [tf.placeholder(tf.float32, - [FLAGS.mem_len, per_core_bsz, FLAGS.d_model]) - for _ in range(FLAGS.n_layer)] - - loss_i, new_mems_i, grads_and_vars_i = single_core_graph( - n_token=n_token, - cutoffs=cutoffs, - is_training=True, - inp=inputs[i], - tgt=labels[i], - mems=mems_i) - - tower_mems.append(mems_i) - tower_losses.append(loss_i) - tower_new_mems.append(new_mems_i) - tower_grads_and_vars.append(grads_and_vars_i) - - ## average losses and gradients across towers - if len(tower_losses) > 1: - loss = tf.add_n(tower_losses) / len(tower_losses) - grads_and_vars = average_grads_and_vars(tower_grads_and_vars) - else: - loss = tower_losses[0] - grads_and_vars = tower_grads_and_vars[0] - grads, all_vars = zip(*grads_and_vars) - - ## clip gradient - clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip) - grads_and_vars = list(zip(clipped, all_vars)) - - ## configure the optimizer - global_step = tf.train.get_or_create_global_step() - - # warmup stage: increase the learning rate linearly - if FLAGS.warmup_steps > 0: - warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \ - * FLAGS.learning_rate - else: - warmup_lr = 0.0 - - # decay stage: decay the learning rate using the cosine schedule - decay_lr = tf.train.cosine_decay( - FLAGS.learning_rate, - global_step=global_step-FLAGS.warmup_steps, - decay_steps=FLAGS.train_steps-FLAGS.warmup_steps, - alpha=FLAGS.min_lr_ratio) - - # choose warmup or decay - 
-  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
-                           warmup_lr, decay_lr)
-
-  # get the train op
-  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
-  train_op = optimizer.apply_gradients(grads_and_vars, global_step)
-
-  ##### Training loop
-  tower_mems_np = [
-      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
-          for layer in range(FLAGS.n_layer)]
-      for core in range(FLAGS.num_core_per_host)
-  ]
-
-  saver = tf.train.Saver()
-
-  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
-    sess.run(tf.global_variables_initializer())
-
-    if FLAGS.warm_start_path is not None:
-      tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
-      saver.restore(sess, FLAGS.warm_start_path)
-
-    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]
-
-    total_loss, prev_step = 0., -1
-    while True:
-      feed_dict = {}
-      for i in range(FLAGS.num_core_per_host):
-        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
-          feed_dict[m] = m_np
-
-      fetched = sess.run(fetches, feed_dict=feed_dict)
-
-      loss_np, tower_mems_np, curr_step = fetched[:3]
-      total_loss += loss_np
-
-      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
-        curr_loss = total_loss / (curr_step - prev_step)
-        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
-            "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
-            curr_step, fetched[-3], fetched[-2],
-            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
-        total_loss, prev_step = 0., curr_step
-
-      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
-        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
-        saver.save(sess, save_path)
-        tf.logging.info("Model saved in path: {}".format(save_path))
-
-      if curr_step == FLAGS.train_steps:
-        break
-
-
-def evaluate(n_token, cutoffs, ps_device):
-  ##### Get input function and model function
-  eval_input_fn, eval_record_info = data_utils.get_input_fn(
-      record_info_dir=FLAGS.record_info_dir,
-      split=FLAGS.eval_split,
-      per_host_bsz=FLAGS.eval_batch_size,
-      tgt_len=FLAGS.tgt_len,
-      num_core_per_host=FLAGS.num_core_per_host,
-      num_hosts=1,
-      use_tpu=False)
-
-  num_batch = eval_record_info["num_batch"]
-  if FLAGS.max_eval_batch > 0:
-    num_batch = FLAGS.max_eval_batch
-  tf.logging.info("num of batches {}".format(num_batch))
-
-  ##### Create computational graph
-  eval_set = eval_input_fn({
-      "batch_size": FLAGS.eval_batch_size,
-      "data_dir": FLAGS.data_dir})
-
-  input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()
-
-  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
-  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)
-
-  per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
-  tower_mems, tower_losses, tower_new_mems = [], [], []
-
-  for i in range(FLAGS.num_core_per_host):
-    with tf.device(assign_to_gpu(i, ps_device)), \
-        tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
-
-      mems_i = [tf.placeholder(tf.float32,
-                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
-                for _ in range(FLAGS.n_layer)]
-
-      loss_i, new_mems_i = single_core_graph(
-          n_token=n_token,
-          cutoffs=cutoffs,
-          is_training=False,
-          inp=inputs[i],
-          tgt=labels[i],
-          mems=mems_i)
-
-      tower_mems.append(mems_i)
-      tower_losses.append(loss_i)
-      tower_new_mems.append(new_mems_i)
-
-  ## sum losses across towers
-  if len(tower_losses) > 1:
-    loss = tf.add_n(tower_losses) / len(tower_losses)
-  else:
-    loss = tower_losses[0]
-
-  ##### Evaluation loop
-  tower_mems_np = [
-      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
-          for layer in range(FLAGS.n_layer)]
-      for core in range(FLAGS.num_core_per_host)
-  ]
-
-  saver = tf.train.Saver()
-
-  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
-    sess.run(tf.global_variables_initializer())
-
-    if FLAGS.eval_ckpt_path is None:
-      eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
-    else:
-      eval_ckpt_path = FLAGS.eval_ckpt_path
-    tf.logging.info("Evaluate {}".format(eval_ckpt_path))
-    saver.restore(sess, eval_ckpt_path)
-
-    fetches = [loss, tower_new_mems, tf.size(label_feed)]
-
-    format_str = "  >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
-        len(str(num_batch)))
-
-    total_loss, total_cnt = 0, 0
-    for step in range(num_batch):
-      if step % (num_batch // 10) == 0:
-        tf.logging.info(format_str.format(step, num_batch))
-
-      feed_dict = {}
-      for i in range(FLAGS.num_core_per_host):
-        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
-          feed_dict[m] = m_np
-
-      fetched = sess.run(fetches, feed_dict=feed_dict)
-
-      loss_np, tower_mems_np, cnt_np = fetched[:3]
-      total_loss += loss_np * cnt_np
-      total_cnt += cnt_np
-
-    avg_loss = total_loss / total_cnt
-    tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
-        avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))
-
-
-def main(unused_argv):
-  del unused_argv  # Unused
-
-  tf.logging.set_verbosity(tf.logging.INFO)
-
-  # Get corpus info
-  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
-  n_token = corpus_info["vocab_size"]
-  cutoffs = corpus_info["cutoffs"][1:-1]
-  tf.logging.info("n_token {}".format(n_token))
-
-  if FLAGS.do_train:
-    train(n_token, cutoffs, "/gpu:0")
-  if FLAGS.do_eval:
-    evaluate(n_token, cutoffs, "/gpu:0")
-
-
-if __name__ == "__main__":
-  tf.app.run()
diff --git a/transformer-xl/tf/vocabulary.py b/transformer-xl/tf/vocabulary.py
deleted file mode 100644
index 20c728f..0000000
--- a/transformer-xl/tf/vocabulary.py
+++ /dev/null
@@ -1,170 +0,0 @@
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from collections import Counter, OrderedDict
-
-import numpy as np
-
-import tensorflow as tf
-
-from tensorflow.gfile import Open as open
-from tensorflow.gfile import Exists as exists
-
-class Vocab(object):
-  def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
-               delimiter=None, vocab_file=None):
-    self.counter = Counter()
-    self.special = special
-    self.min_freq = min_freq
-    self.max_size = max_size
-    self.lower_case = lower_case
-    self.delimiter = delimiter
-    self.vocab_file = vocab_file
-
-  def tokenize(self, line, add_eos=False, add_double_eos=False):
-    line = line.strip()
-    # convert to lower case
-    if self.lower_case:
-      line = line.lower()
-
-    # empty delimiter '' will evaluate False
-    if self.delimiter == '':
-      symbols = line
-    else:
-      symbols = line.split(self.delimiter)
-
-    if add_double_eos: # lm1b
-      return ['<S>'] + symbols + ['<S>']
-    elif add_eos:
-      return symbols + ['<eos>']
-    else:
-      return symbols
-
-  def count_file(self, path, verbose=False, add_eos=False):
-    if verbose: print('counting file {} ...'.format(path))
-    assert exists(path)
-
-    sents = []
-    with open(path, 'r') as f:
-      for idx, line in enumerate(f):
-        if verbose and idx > 0 and idx % 500000 == 0:
-          print('  line {}'.format(idx))
-        symbols = self.tokenize(line, add_eos=add_eos)
-        self.counter.update(symbols)
-        sents.append(symbols)
-
-    return sents
-
-  def count_sents(self, sents, verbose=False):
-    """
-      sents : a list of sentences, each a list of tokenized symbols
-    """
-    if verbose:
-      print('counting {} sents ...'.format(len(sents)))
-    for idx, symbols in enumerate(sents):
-      if verbose and idx > 0 and idx % 500000 == 0:
-        print('  line {}'.format(idx))
-      self.counter.update(symbols)
-
-  def _build_from_file(self, vocab_file):
-    self.idx2sym = []
-    self.sym2idx = OrderedDict()
-
-    with open(vocab_file, 'r') as f:
-      for line in f:
-        symb = line.strip().split()[0]
-        self.add_symbol(symb)
-    self.unk_idx = self.sym2idx['<UNK>']
-
-  def build_vocab(self):
-    if self.vocab_file:
-      print('building vocab from {}'.format(self.vocab_file))
-      self._build_from_file(self.vocab_file)
-      print('final vocab size {}'.format(len(self)))
-    else:
-      print('building vocab with min_freq={}, max_size={}'.format(
-          self.min_freq, self.max_size))
-      self.idx2sym = []
-      self.sym2idx = OrderedDict()
-
-      for sym in self.special:
-        self.add_special(sym)
-
-      for sym, cnt in self.counter.most_common(self.max_size):
-        if cnt < self.min_freq: break
-        self.add_symbol(sym)
-
-      print('final vocab size {} from {} unique tokens'.format(
-          len(self), len(self.counter)))
-
-  def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
-                  add_double_eos=False):
-    if verbose: print('encoding file {} ...'.format(path))
-    assert exists(path)
-    encoded = []
-    with open(path, 'r') as f:
-      for idx, line in enumerate(f):
-        if verbose and idx > 0 and idx % 500000 == 0:
-          print('  line {}'.format(idx))
-        symbols = self.tokenize(line, add_eos=add_eos,
-                                add_double_eos=add_double_eos)
-        encoded.append(self.convert_to_nparray(symbols))
-
-    if ordered:
-      encoded = np.concatenate(encoded)
-
-    return encoded
-
-  def encode_sents(self, sents, ordered=False, verbose=False):
-    if verbose: print('encoding {} sents ...'.format(len(sents)))
-    encoded = []
-    for idx, symbols in enumerate(sents):
-      if verbose and idx > 0 and idx % 500000 == 0:
-        print('  line {}'.format(idx))
-      encoded.append(self.convert_to_nparray(symbols))
-
-    if ordered:
-      encoded = np.concatenate(encoded)
-
-    return encoded
-
-  def add_special(self, sym):
-    if sym not in self.sym2idx:
-      self.idx2sym.append(sym)
-      self.sym2idx[sym] = len(self.idx2sym) - 1
-      setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
-
-  def add_symbol(self, sym):
-    if sym not in self.sym2idx:
-      self.idx2sym.append(sym)
-      self.sym2idx[sym] = len(self.idx2sym) - 1
-
-  def get_sym(self, idx):
-    assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
-    return self.idx2sym[idx]
-
-  def get_idx(self, sym):
-    if sym in self.sym2idx:
-      return self.sym2idx[sym]
-    else:
-      assert hasattr(self, 'unk_idx')
-      return self.sym2idx.get(sym, self.unk_idx)
-
-  def get_symbols(self, indices):
-    return [self.get_sym(idx) for idx in indices]
-
-  def get_indices(self, symbols):
-    return [self.get_idx(sym) for sym in symbols]
-
-  def convert_to_nparray(self, symbols):
-    nparray = np.array(self.get_indices(symbols), dtype=np.int64)
-    return nparray
-
-  def convert_to_sent(self, indices, exclude=None):
-    if exclude is None:
-      return ' '.join([self.get_sym(idx) for idx in indices])
-    else:
-      return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
-
-  def __len__(self):
-    return len(self.idx2sym)