commit d0457b6571
87 changed files with 1737 additions and 11132 deletions
4 .gitignore vendored
@@ -1,3 +1,5 @@
__pycache__
.idea/
data/
saved_models/
results/

.python-version
@@ -1 +1 @@
3.12
3.11
26 README.md
@@ -0,0 +1,26 @@
# neural compression

Example usage:

```shell
python main.py --debug train --dataset enwik9 --data-root ~/data/datasets/ml --method optuna --model transformer --model-save-path ~/data/ml-models/test-transformer.pt

python benchmark.py --debug train --dataset enwik9 --data-root ~/data/datasets/ml --method optuna --model cnn --model-save-path ~/data/ml-models/test-cnn.pt
```

## Running locally

```shell
uv sync --all-extras
```

## Running on the Ghent University HPC

See the [Infrastructure docs](https://docs.hpc.ugent.be/infrastructure/#gpu-clusters) for more information about the clusters.

```shell
module swap cluster/joltik  # Specify the (GPU) cluster, {joltik,accelgor,litleo}

qsub job.pbs  # Submit job
qstat         # Check status
```
12 benchmark.py Normal file
@@ -0,0 +1,12 @@
from main import main
from src.utils.benchmark import execute_benchmark
from src.utils.benchmark_dataclasses import BenchmarkItem


def benchmark():
    # Just calling `main` is the easiest way to allow all functionality
    execute_benchmark(benchmark_item=BenchmarkItem(task=main, arguments={}), results_dir="results")


if __name__ == "__main__":
    benchmark()
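Because the wrapped `main` is called with an empty `arguments` dict, any CLI flags still reach it through `sys.argv`, which is how the README's `python benchmark.py --debug train …` invocation works. A minimal sketch of pinning `argv` explicitly instead (the flag values are illustrative, not from this commit):

```python
# Illustrative only: fix argv so the benchmarked main() sees known flags.
import sys

from main import main
from src.utils.benchmark import execute_benchmark
from src.utils.benchmark_dataclasses import BenchmarkItem

sys.argv = ["main.py", "--debug", "train", "--dataset", "lorem_ipsum",
            "--method", "optuna", "--model", "cnn"]
execute_benchmark(benchmark_item=BenchmarkItem(task=main, arguments={}), results_dir="results")
```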
dataset_loaders/Dataset.py (deleted)
@@ -1,26 +0,0 @@
from abc import abstractmethod, ABC
from os.path import join, curdir
from typing import Callable

from torch.utils.data import Dataset as TorchDataset

"""
Author: Tibo De Peuter
"""
class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""
    @abstractmethod
    def __init__(self, root: str, transform: Callable = None):
        """
        :param root: Relative path to the dataset root directory
        """
        self._root: str = join(curdir, 'data', root)
        self.transform = transform
        self.dataset = None

    @property
    def root(self):
        return self._root

    def __len__(self):
        return len(self.dataset)
dataset_loaders/EnWik9.py (deleted)
@@ -1,43 +0,0 @@
from datasets import load_dataset
from torch.utils.data import Dataset
import torch
from os.path import curdir, join
from typing import Callable


class EnWik9DataSet(Dataset):
    def __init__(self, root: str = "data", transform: Callable | None = None):
        super().__init__()
        self.transform = transform

        # HuggingFace dataset: string text
        path = join(curdir, root)
        data = load_dataset("haukur/enwik9", cache_dir=path, split="train")

        # Extract raw text
        text = data["text"]

        # Convert text (Python string) → bytes → tensor of ints 0–255
        # UTF-8 multi-byte characters expand to several bytes, each still in 0–255
        byte_data = "".join(text).encode("utf-8", errors="replace")
        self.data = torch.tensor(list(byte_data), dtype=torch.long)

        # Model uses fixed 128-length context
        self.context_length = 128

    def __len__(self):
        # number of sliding windows
        return len(self.data) - self.context_length

    def __getitem__(self, idx):
        # context window
        x = self.data[idx : idx + self.context_length]

        # next byte target
        y = self.data[idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
dataset_loaders/LoremIpsumDataset.py (deleted)
@@ -1,34 +0,0 @@
from typing import Callable

import torch
from os.path import curdir, join
from lorem.text import TextLorem
from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self, root: str = "data", transform: Callable = None):
        super().__init__(root, transform)

        # Generate text and convert to bytes
        _lorem = TextLorem()
        _text = ' '.join(_lorem._word() for _ in range(512))

        path = join(curdir, "data")
        self._root = path
        # Convert text to bytes (UTF-8 encoded)
        self.dataset = torch.tensor([ord(c) % 256 for c in list(_text)], dtype=torch.long)
        self.context_length = 128

    def __len__(self):
        # Number of possible sequences of length sequence_length
        return self.dataset.size(0) - self.context_length

    def __getitem__(self, idx):
        x = self.dataset[idx: idx + self.context_length]
        y = self.dataset[idx + self.context_length]

        if self.transform is not None:
            x = self.transform(x)

        return x, y
dataset_loaders/__init__.py (deleted)
@@ -1,3 +0,0 @@
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .Dataset import Dataset
37 job.pbs
@@ -1,15 +1,40 @@
#!/bin/bash

#PBS -N nc-cnn-enwik9-optuna
#PBS -l gpus=1
#PBS -l walltime=03:00:00
#PBS -l walltime=08:00:00
#PBS -l mem=60gb
#PBS -m abe

module load PyTorch/2.1.2-foss-2023a-CUDA-12.1.1
CACHE_DIR="${VSC_SCRATCH}/.cache"  # Directory to use as cache
UV_DIR="${VSC_SCRATCH}/uv"         # Directory to install packages
VENV="${UV_DIR}/venv"

DATA_DIR="${VSC_DATA}/datasets"
RESULTS_DIR="${VSC_DATA}/neural-compression/$(date +%Y%m%d-%H%M-%S%N)-results"

mkdir -p "${DATA_DIR}" "${RESULTS_DIR}" || true

module purge
module load PyTorch-bundle/2.1.2-foss-2023a-CUDA-12.1.1
module load Optuna/3.5.0-foss-2023a
module load matplotlib/2.2.5-foss-2023a-Python-2.7.18

cd $PBS_O_WORKDIR
cd "${PBS_O_WORKDIR}" || exit

source training_env/bin/activate
UV_PYTHON_INSTALL_DIR="${UV_DIR}/python" \
    uv --cache-dir="${CACHE_DIR}/uv" \
    venv "${VENV}" --clear

python main_cnn.py --method train
source "${VENV}/bin/activate"

UV_PYTHON_INSTALL_DIR="${UV_DIR}/python" \
    uv --cache-dir="${CACHE_DIR}/uv" \
    sync --active --no-dev

cd "${PBS_O_WORKDIR}" || exit

python main.py train \
    --method=optuna \
    --dataset=enwik9 --data-root="${DATA_DIR}" \
    --model=cnn --model-save-path="${RESULTS_DIR}/cnn-enwik9-optuna.pt"
99 main.py
@@ -1,64 +1,49 @@
from argparse import ArgumentParser
from math import ceil
from src.args import parse_arguments
from src.process import compress
from src.train import train
from src.utils import determine_device

import torch
from torch.utils.data import DataLoader

from dataset_loaders import EnWik9DataSet, LoremIpsumDataset, Dataset
from trainers import OptunaTrainer, Trainer, FullTrainer


def main():
    args, print_help = parse_arguments()

BATCH_SIZE = 64
    device = determine_device()
    print(f"Running on device: {device}...")

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
    match args.mode:
        case 'train':
            size = int(args.size) if args.size else None
            if args.method == 'optuna':
                size = 2 ** 12
                print(f"Using size {size} for optuna (was {args.size})")
            if args.debug:
                size = 2 ** 10
                print(f"Using size {size} for debug (was {args.size})")

            train(
                device=device,
                dataset=args.dataset,
                data_root=args.data_root,
                n_trials=3 if args.debug else None,
                size=size,
                method=args.method,
                model_name=args.model,
                model_path=args.model_load_path,
                model_out=args.model_save_path
            )

        case 'compress':
            compress(device=device,
                     model_path=args.model_load_path,
                     input_file=args.input_file,
                     output_file=args.output_file
                     )

        case _:
            raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

    print("Done")

# hyper parameters
context_length = 128

if __name__ == "__main__":
    print(f"Running on device: {DEVICE}...")
    parser = ArgumentParser()
    parser.add_argument("--method", choices=["optuna", "train"], required=True)
    parser.add_argument("--models-path", type=str, required=False)
    args = parser.parse_args()

    print("Loading in the dataset...")
    if args.method == "train":
        dataset: Dataset = EnWik9DataSet(transform=lambda x: x.to(DEVICE))
    elif args.method == "optuna":
        dataset: Dataset = LoremIpsumDataset(transform=lambda x: x.to(DEVICE))
    else:
        raise ValueError(f"Unknown method: {args.method}")

    dataset_length = len(dataset)
    print(f"Dataset size = {dataset_length}")

    training_size = ceil(0.8 * dataset_length)

    print(f"Training set size = {training_size}, Validation set size {dataset_length - training_size}")

    train_set, validate_set = torch.utils.data.random_split(dataset,
                                                            [training_size, dataset_length - training_size])
    training_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=BATCH_SIZE, shuffle=False)
    loss_fn = torch.nn.CrossEntropyLoss()

    model = None
    if args.model_path is not None:
        print("Loading the models...")
        model = torch.load(args.model_path)

    trainer: Trainer = OptunaTrainer() if args.method == "optuna" else FullTrainer()

    trainer.execute(
        model=model,
        train_loader=training_loader,
        validation_loader=validation_loader,
        loss_fn=loss_fn,
        n_epochs=200,
        device=DEVICE
    )
    main()
models/__init__.py (deleted)
@@ -1,2 +0,0 @@
from .cnn import CNNPredictor
from .transformer import Transformer

models/transformer/__init__.py (deleted)
@@ -1 +0,0 @@
from .transformer import Transformer
models/transformer/transformer.py (deleted)
@@ -1,60 +0,0 @@
from typing import Optional

import torch.nn as nn
from torch import Tensor


class Transformer(nn.Transformer):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-05
    ):
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=False,
            norm_first=False,
            device=None,
            dtype=None
        )

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        src_is_causal: Optional[bool] = None,
        tgt_is_causal: Optional[bool] = None,
        memory_is_causal: bool = False,
    ) -> Tensor:
        return super().forward(
            src,
            tgt,
            src_mask,
            tgt_mask,
            memory_mask,
            src_key_padding_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            src_is_causal,
            tgt_is_causal,
            memory_is_causal,
        )
pyproject.toml
@@ -3,14 +3,21 @@ name = "project-ml"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.12"
requires-python = ">=3.11"
dependencies = [
    "datasets>=4.4.1",
    "datasets>=3.2.0",
    "huggingface_hub==0.27.0",
    "fsspec==2024.9.0",
    "lorem>=0.1.1",
    "matplotlib>=3.10.7",
    "numpy>=2.3.4",
    "optuna>=4.5.0",
    "torch>=2.9.0",
    "torchdata>=0.11.0",
    "torchvision>=0.24.0",
]

[project.optional-dependencies]
dev = [
    "hydra-core>=1.3.2",
    "matplotlib>=3.10.7",
    "memray>=1.19.1",
    "optuna==4.5.0",
    "torch==2.9.0",
    "torchdata==0.7.1",
    "torchvision==0.24.0",
]

Binary file not shown.
45 src/args.py Normal file
@@ -0,0 +1,45 @@
from argparse import ArgumentParser

from src.dataset_loaders import dataset_called


def parse_arguments():
    parser = ArgumentParser(prog="NeuralCompression")
    parser.add_argument("--debug", "-d", action="store_true", required=False,
                        help="Enable debug mode: smaller datasets, more information")
    parser.add_argument("--verbose", "-v", action="store_true", required=False,
                        help="Enable verbose mode")

    dataparser = ArgumentParser(add_help=False)
    dataparser.add_argument("--data-root", type=str, required=False)
    dataparser.add_argument("--dataset", choices=dataset_called.keys(), required=True)

    modelparser = ArgumentParser(add_help=False)
    modelparser.add_argument("--model", "-m", type=str, required=False,
                             help="Which model to use")
    modelparser.add_argument("--model-load-path", type=str, required=False,
                             help="Filepath to the model to load")
    modelparser.add_argument("--model-save-path", type=str, required=False,
                             help="Filepath to the model to save")

    fileparser = ArgumentParser(add_help=False)
    fileparser.add_argument("--input-file", "-i", required=False, type=str)
    fileparser.add_argument("--output-file", "-o", required=False, type=str)

    subparsers = parser.add_subparsers(dest="mode", required=True,
                                       help="Mode to run in")

    train_parser = subparsers.add_parser("train",
                                         parents=[dataparser, modelparser],
                                         help="Do a full training")
    train_parser.add_argument("--method",
                              choices=["fetch", "optuna", "full"], required=True,
                              help="Method to use for training")
    train_parser.add_argument("--size", "-s", type=int, required=False,
                              help="Size of the subset of the dataset to use")

    compress_parser = subparsers.add_parser("compress", parents=[modelparser, fileparser])

    decompress_parser = subparsers.add_parser("decompress", parents=[modelparser, fileparser])

    return parser.parse_args(), parser.print_help
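For orientation, a hedged sketch of what the resulting namespace looks like when a train invocation is parsed (the argv values below are examples, not from the commit):

```python
# Illustrative usage of parse_arguments(); argv values are made up.
import sys

from src.args import parse_arguments

sys.argv = ["main.py", "--debug", "train", "--dataset", "enwik9",
            "--method", "optuna", "--model", "cnn"]
args, print_help = parse_arguments()
assert args.mode == "train" and args.method == "optuna" and args.debug
```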
128 src/dataset_loaders/Dataset.py Normal file
@@ -0,0 +1,128 @@
from abc import abstractmethod, ABC
from itertools import accumulate
from os.path import join, curdir
from typing import Callable

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm

"""
Author: Tibo De Peuter
"""


class Dataset(TorchDataset, ABC):
    """Abstract base class for datasets."""

    @abstractmethod
    def __init__(self,
                 name: str,
                 root: str | None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = -1
                 ):
        """
        :param root: Path to the dataset root directory
        :param split: The dataset split, e.g. 'train', 'validation', 'test'
        :param size: Override the maximum size of the dataset, useful for debugging
        """
        if root is None:
            root = join(curdir, 'data')

        self._root = join(root, name)
        self.split = split
        self.transform = transform
        self.size = size
        self.data = None

        self.chunk_offsets: list[int] = []
        self.bytes: bytes = bytes()
        self.tensor: Tensor = torch.tensor([])

    @property
    def root(self):
        return self._root

    def __len__(self):
        return len(self.dataset)

    def process_data(self):
        self.chunk_offsets = self.get_offsets()
        if self.size == -1:
            # Just use the whole dataset
            self.bytes = ''.join(tqdm(self.data, desc="Encoding data", leave=False)).encode('utf-8', errors='replace')
        else:
            # Use only a partition, calculate offsets
            self.bytes = (''.join(tqdm(self.data[:len(self.chunk_offsets)], desc="Encoding data", leave=False))
                          .encode('utf-8', errors='replace'))

        bytes_array = np.frombuffer(self.bytes, dtype=np.uint8)  # Zero-copy
        self.tensor = torch.from_numpy(bytes_array).to(torch.long, non_blocking=True)

    def get_offsets(self):
        """
        Calculate for each chunk how many bytes came before it
        """
        data = self.data
        size = self.size

        if size == -1:
            return [0, *accumulate(tqdm(map(len, data), desc="Calculating offsets", leave=False, total=len(data)))]

        offsets = [0]
        total = 0
        append = offsets.append
        for chunk in tqdm(data):
            if total >= size:
                break
            total += len(chunk)
            append(total)
        return offsets

    def get_chunked_item(self, idx: int, offsets: list[int], context_length: int):
        item = ''

        # Determine first chunk in which item is located
        chunk_idx = 0
        while idx >= offsets[chunk_idx]:
            chunk_idx += 1
        chunk_idx -= 1

        # Extract item from chunks
        chunk = str(self.data[chunk_idx])
        chunk_start = offsets[chunk_idx]

        chunk_item_start = idx - chunk_start
        item_len_remaining = context_length + 1

        assert len(item) + item_len_remaining == context_length + 1

        while chunk_item_start + item_len_remaining > len(chunk):
            adding_now_len = len(chunk) - chunk_item_start
            item += chunk[chunk_item_start:]

            chunk_idx += 1
            chunk = str(self.data[chunk_idx])

            chunk_item_start = 0
            item_len_remaining -= adding_now_len

            assert len(item) + item_len_remaining == context_length + 1

        item += chunk[chunk_item_start: chunk_item_start + item_len_remaining]

        assert len(item) == context_length + 1, f"Expected item of length {context_length + 1}, was {len(item)}"

        # Transform to tensor
        data = ''.join(item).encode('utf-8', errors='replace')
        t = torch.tensor(list(data), dtype=torch.long)
        x, y = t[:-1], t[-1]

        if self.transform:
            x = self.transform(x)

        return x, y
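The linear scan over `offsets` in `get_chunked_item` is O(n) per lookup; since `chunk_offsets` is sorted, a binary search yields the same chunk index. A standalone sketch of that alternative (not part of the commit):

```python
# Hypothetical replacement for the linear offset scan, using bisect.
from bisect import bisect_right

def find_chunk(offsets: list[int], idx: int) -> int:
    """Return the index of the chunk containing global byte position idx."""
    # offsets[k] is the number of bytes before chunk k; offsets is sorted.
    return bisect_right(offsets, idx) - 1

assert find_chunk([0, 10, 25], 9) == 0   # byte 9 lives in chunk 0
assert find_chunk([0, 10, 25], 10) == 1  # byte 10 starts chunk 1
```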
59 src/dataset_loaders/EnWik9.py Normal file
@@ -0,0 +1,59 @@
from math import ceil
from typing import Callable

from datasets import load_dataset, Features, Value

from .Dataset import Dataset


class EnWik9DataSet(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/haukur/enwik9
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable | None = None,
                 size: int = -1
                 ):
        super().__init__('enwik9', root, split, transform, size)

        print("Loading from HuggingFace")
        ft = Features({'text': Value('string')})
        # Don't pass split here, dataset only contains training
        text_chunks = load_dataset("haukur/enwik9", cache_dir=self.root, split='train', features=ft)
        self.data = text_chunks['text']
        self.size = size

        # Model uses fixed 128-length context
        self.context_length = 128

        self.process_data()

        # Define splits manually, because they do not exist in the dataset
        split_point = ceil(self.chunk_offsets[-1] * 0.8)

        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
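A minimal sketch of wiring the two splits into loaders, mirroring what src/train.py does below (the batch size and `size` cap are arbitrary example values):

```python
# Illustrative only; mirrors the dataset/DataLoader wiring in src/train.py.
from torch.utils.data import DataLoader

from src.dataset_loaders import EnWik9DataSet

train_set = EnWik9DataSet(root="data", split="train", size=2**12)
val_set = EnWik9DataSet(root="data", split="validation", size=2**12)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
x, y = next(iter(train_loader))  # x: (64, 128) byte ids, y: (64,) next bytes
```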
63 src/dataset_loaders/LoremIpsumDataset.py Normal file
@@ -0,0 +1,63 @@
from math import ceil
from typing import Callable

from lorem.text import TextLorem
from tqdm import tqdm

from .Dataset import Dataset


class LoremIpsumDataset(Dataset):
    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = 2**30
                 ):
        super().__init__('lorem_ipsum', root, split, transform, size)

        _lorem = TextLorem()

        self.data = ' '.join(_lorem._word() for _ in tqdm(range(size), desc="Generating data"))
        self.size = size

        self.context_length = 128

        self.process_data()

        split_point = ceil(self.chunk_offsets[-1] * 0.8)

        if self.split == 'train':
            self.start_byte = 0
            self.end_byte = split_point
        elif self.split == 'validation':
            self.start_byte = split_point
            self.end_byte = self.chunk_offsets[-1]
        else:
            raise ValueError("split must be 'train' or 'validation'")

        print("Done initializing dataset")

    def __len__(self):
        return self.end_byte - self.start_byte - self.context_length

    def __getitem__(self, idx):
        # Get sequence of characters
        # x_str = self.text[idx: idx + self.context_length]
        # y_char = self.text[idx + self.context_length]
        #
        # # Convert to tensors
        # x = torch.tensor([ord(c) % 256 for c in x_str], dtype=torch.long)
        # y = torch.tensor(ord(y_char) % 256, dtype=torch.long)
        #
        # if self.transform is not None:
        #     x = self.transform(x)
        #
        # return x, y
        x = self.tensor[self.start_byte + idx:self.start_byte + idx + self.context_length]
        y = self.tensor[self.start_byte + idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
51 src/dataset_loaders/OpenGenomeDataset.py Normal file
@@ -0,0 +1,51 @@
from typing import Callable

from datasets import load_dataset, Value, Features

from .Dataset import Dataset


class OpenGenomeDataset(Dataset):
    """
    Hugging Face: https://huggingface.co/datasets/LongSafari/open-genome

    :param split: Either 'train', 'test' or 'validation'
    :param stage: Either 'sample', 'stage1' or 'stage2'.
        'sample' only provides a 'validation' split
    """

    def __init__(self,
                 root: str | None = None,
                 split: str = 'train',
                 transform: Callable = None,
                 size: int = -1,
                 stage: str = 'stage2'
                 ):
        super().__init__('open_genome', root, split, transform, size)

        print(f"Loading from HuggingFace (stage: {stage}, split: {split})")
        ft = Features({'text': Value('string')})
        data = load_dataset("LongSafari/open-genome", stage, split=split, cache_dir=self.root, features=ft)
        self.data = data['text']
        self.size = size

        # Model uses fixed 128-length context
        self.context_length = 128

        self.process_data()

        print("Done initializing dataset")

    def __len__(self):
        # return len(self.data) - self.context_length
        return self.chunk_offsets[-1] - self.context_length

    def __getitem__(self, idx):
        # return self.get_chunked_item(idx, self.chunk_offsets, self.context_length)
        x = self.tensor[idx:idx + self.context_length]
        y = self.tensor[idx + self.context_length]

        if self.transform:
            x = self.transform(x)

        return x, y
10 src/dataset_loaders/__init__.py Normal file
@@ -0,0 +1,10 @@
from .Dataset import Dataset
from .EnWik9 import EnWik9DataSet
from .LoremIpsumDataset import LoremIpsumDataset
from .OpenGenomeDataset import OpenGenomeDataset

dataset_called: dict[str, type[Dataset]] = {
    'enwik9': EnWik9DataSet,
    'lorem_ipsum': LoremIpsumDataset,
    'opengenome': OpenGenomeDataset
}
14 src/models/Model.py Normal file
@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod

from torch import nn


class Model(nn.Module, ABC):
    @abstractmethod
    def __init__(self, loss_function=None):
        super().__init__()
        self._loss_function = loss_function

    @property
    def loss_function(self):
        return self._loss_function
8 src/models/__init__.py Normal file
@@ -0,0 +1,8 @@
from .Model import Model
from .cnn import CNNPredictor
from .transformer import ByteTransformer

model_called: dict[str, type[Model]] = {
    'cnn': CNNPredictor,
    'transformer': ByteTransformer
}
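Registering a new architecture means subclassing Model and adding it to the model_called registry. A minimal sketch (the toy perceptron below is a made-up example, not in the commit):

```python
# Hypothetical example of plugging a new model into the registry.
import torch.nn as nn

from src.models import Model, model_called

class BytePerceptron(Model):
    """Toy next-byte predictor: embeds 128 bytes and maps them to 256 logits."""
    def __init__(self, embed_dim=32, context_length=128):
        super().__init__(nn.CrossEntropyLoss())
        self.embed = nn.Embedding(256, embed_dim)
        self.head = nn.Linear(context_length * embed_dim, 256)

    def forward(self, x):                 # x: (B, 128) byte ids
        e = self.embed(x).flatten(1)      # (B, 128 * embed_dim)
        return self.head(e)               # (B, 256) logits over the next byte

model_called['perceptron'] = BytePerceptron
```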
@@ -1,14 +1,16 @@
import torch
import torch.nn as nn

class CNNPredictor(nn.Module):
from src.models import Model


class CNNPredictor(Model):
    def __init__(
        self,
        vocab_size=256,
        embed_dim=64,
        hidden_dim=128,
    ):
        super().__init__()
        super().__init__(nn.CrossEntropyLoss())

        # 1. Embedding: maps bytes (0–255) → vectors
        self.embed = nn.Embedding(vocab_size, embed_dim)
1 src/models/transformer/__init__.py Normal file
@@ -0,0 +1 @@
from .transformer import ByteTransformer
70 src/models/transformer/transformer.py Normal file
@@ -0,0 +1,70 @@
from typing import Optional

import torch.nn as nn
from torch import Tensor, arange

from src.models import Model


class LearnedPositionalEncoding(Model):
    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_emb = nn.Embedding(max_len, d_model)

    def forward(self, x):
        # x: [seq, batch, d_model]
        seq_len = x.size(0)
        positions = arange(seq_len, device=x.device).unsqueeze(1)  # [seq, 1]
        return x + self.pos_emb(positions)  # broadcast over batch


class ByteTransformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        layer_norm_eps=1e-05,
        max_len=128
    ):
        super().__init__()
        self.src_embedding = nn.Embedding(256, d_model)
        self.tgt_embedding = nn.Embedding(256, d_model)

        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=False,
            norm_first=False,
            device=None,
            dtype=None,
        )

        self.output_proj = nn.Linear(d_model, 256)

        self.loss_function = nn.CrossEntropyLoss()

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
    ) -> Tensor:
        src_embeds = self.src_embedding(src)
        tgt_embeds = self.tgt_embedding(tgt)

        src_pos = self.src_pos(src_embeds)
        tgt_pos = self.tgt_pos(tgt_embeds)

        return self.output_proj(self.transformer(src_pos, tgt_pos))
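Note that this forward passes no tgt_mask, so decoder self-attention can attend to future positions. If causal decoding were wanted, nn.Transformer ships a mask helper; a hedged sketch of threading it through (not part of this commit):

```python
# Hypothetical causal-mask variant of ByteTransformer.forward; not in the commit.
import torch.nn as nn

def causal_forward(model, src, tgt):
    # Square mask with -inf above the diagonal: position i only attends to <= i.
    tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0), device=tgt.device)
    src_pos = model.src_pos(model.src_embedding(src))
    tgt_pos = model.tgt_pos(model.tgt_embedding(tgt))
    return model.output_proj(model.transformer(src_pos, tgt_pos, tgt_mask=tgt_mask))
```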
30 src/process.py Normal file
@@ -0,0 +1,30 @@
import torch


def compress(
    device,
    model_path: str,
    output_file: str,
    input_file: str | None = None
):
    # Get input to compress
    if input_file:
        with open(input_file, "rb") as file:
            byte_data = file.read()
    else:
        # Read from stdin
        text = input()
        byte_data = text.encode('utf-8', errors='replace')

    tensor = torch.tensor(list(byte_data), dtype=torch.long)
    print(tensor)

    # Get model
    model = torch.load(model_path, weights_only=False)

    # TODO Feed to model for compression, store result
    return


def decompress():
    raise NotImplementedError("Decompression is not implemented yet")
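The compression TODO ultimately needs per-byte probabilities from the model; a hedged sketch of extracting them for one 128-byte context window (the entropy-coding step is out of scope here, and this is not the commit's implementation):

```python
# Illustrative only: next-byte distribution for one context window.
import torch
import torch.nn.functional as F

def next_byte_probs(model, context: torch.Tensor) -> torch.Tensor:
    """context: (128,) long tensor of byte ids -> (256,) probability vector."""
    model.eval()
    with torch.no_grad():
        logits = model(context.unsqueeze(0))   # (1, 256) for the CNN predictor
        return F.softmax(logits, dim=-1).squeeze(0)
```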
74 src/train.py Normal file
@@ -0,0 +1,74 @@
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.dataset_loaders import dataset_called
from src.models import model_called
from src.trainers import OptunaTrainer, Trainer, FullTrainer


def train(
    device,
    dataset: str,
    data_root: str,
    n_trials: int | None = None,
    size: int | None = None,
    method: str = 'optuna',
    model_name: str | None = None,
    model_path: str | None = None,
    model_out: str | None = None
):
    batch_size = 2

    assert model_name or model_path, "Either a model to train or a model to load from model_path must be provided"

    if model_name:
        print("Creating model")
        model = model_called[model_name]
    else:
        print("Loading model from disk")
        model = torch.load(model_path, weights_only=False)

    dataset_common_args = {
        'root': data_root,
        'transform': lambda x: x.to(device),
    }

    if size:
        dataset_common_args['size'] = size

    print("Loading in the dataset...")
    if dataset in dataset_called:
        training_set = dataset_called[dataset](split='train', **dataset_common_args)
        validate_set = dataset_called[dataset](split='validation', **dataset_common_args)
    else:
        # TODO Allow to import arbitrary files
        raise NotImplementedError("Importing external datasets is not implemented yet")

    if method == 'fetch':
        # TODO Move to earlier in chain, because now everything is converted into tensors as well?
        exit(0)

    print(f"Training set size = {len(training_set)}, Validation set size {len(validate_set)}")
    training_loader = DataLoader(training_set, batch_size=batch_size, shuffle=True)
    validation_loader = DataLoader(validate_set, batch_size=batch_size, shuffle=False)

    trainer: Trainer = OptunaTrainer(n_trials=n_trials) if method == "optuna" else FullTrainer()

    print("Training")
    best_model = trainer.execute(
        model=model,
        train_loader=training_loader,
        validation_loader=validation_loader,
        n_epochs=n_trials,
        device=device
    )

    print("Saving model...")
    f = model_out or f"saved_models/{model.__class__.__name__}.pt"
    # Make sure path exists
    Path(f).parent.mkdir(parents=True, exist_ok=True)
    torch.save(best_model, f)
    print(f"Saved model to '{f}'")
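A hedged sketch of calling src.train.train directly, mirroring what main.py's train branch passes in debug mode (the path values are illustrative):

```python
# Illustrative call, roughly equivalent to `python main.py --debug train ...`.
from src.train import train
from src.utils import determine_device

train(
    device=determine_device(),
    dataset="enwik9",
    data_root="data",
    n_trials=3,          # debug mode caps Optuna trials at 3
    size=2 ** 10,        # debug mode shrinks the dataset
    method="optuna",
    model_name="cnn",
    model_out="results/cnn-enwik9-optuna.pt",
)
```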
@@ -1,26 +1,26 @@
from typing import Callable

import torch
from torch import nn as nn
from torch import nn
from torch.utils.data import DataLoader

from .trainer import Trainer
from .train import train
from .trainer import Trainer
from ..models import Model
from ..utils import print_losses


class FullTrainer(Trainer):
    def execute(
        self,
        model: nn.Module | None,
        model: Model,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        n_epochs: int | None,
        device: str
    ) -> None:
    ) -> nn.Module:
        if model is None:
            raise ValueError("Model must be provided: run optuna optimizations first")

        model.to(device)
        train_loss, val_loss = train(model, train_loader, validation_loader, loss_fn, n_epochs)
        print_losses(train_loss, val_loss)
        train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, device=device)
        print_losses(train_loss, val_loss)

        return model
72 src/trainers/OptunaTrainer.py Normal file
@@ -0,0 +1,72 @@
import optuna
import optuna.trial as tr
from torch import nn
from torch.utils.data import DataLoader

from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, ByteTransformer


def create_model(trial: tr.Trial, model: nn.Module):
    match model.__class__:
        case CNNPredictor.__class__:
            return model(
                hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
                embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
                vocab_size=256,
            )
        case ByteTransformer.__class__:
            nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
            # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
            return model(
                d_model=128,  # hard coded for now as data loaders provide fixed (B, 128) tensors
                nhead=nhead,
                num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
                num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
                dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
                dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
                activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
                layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
            )
    return None


def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    model: Model,
    device: str
):
    model = create_model(trial, model).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, model.loss_function, device=device)
    return min(validation_loss)


class OptunaTrainer(Trainer):
    def __init__(self, n_trials: int | None = None):
        super().__init__()
        self.n_trials = n_trials if n_trials else 20
        print(f"Creating Optuna trainer(n_trials = {self.n_trials})")

    def execute(
        self,
        model: Model,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        n_epochs: int,
        device: str
    ) -> nn.Module:
        study = optuna.create_study(direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, model, device),
            n_trials=self.n_trials
        )

        best_params = study.best_trial.params
        best_model = model(
            **best_params
        )

        return best_model
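For orientation, a self-contained sketch of the Optuna pattern this trainer follows: suggest hyperparameters inside an objective, optimize, then read back the best trial (the toy objective below is made up, not the project's):

```python
# Minimal standalone Optuna example of the suggest/optimize pattern used above.
import optuna

def objective(trial: optuna.trial.Trial) -> float:
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    dropout = trial.suggest_float("dropout", 0.01, 0.5, log=True)
    # Stand-in for a validation loss; a real objective would train a model here.
    return (hidden_dim - 128) ** 2 * 1e-4 + dropout

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=20)
print(study.best_trial.params)
```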
@@ -1,38 +1,60 @@
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from typing import Callable

from ..models import ByteTransformer, Model


def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
    if isinstance(model, ByteTransformer):
        tgt_in = torch.cat([
            torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
            x[:, :-1]
        ], dim=1)
        logits = model(x, tgt_in)

        # only consider the last time step of the model where the full context
        # is available
        return logits[:, -1, :]
    return model(x)


def train(
    model: nn.Module,
    model: Model,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    epochs: int = 100,
    loss_fn: Callable,
    epochs: int | None = None,
    learning_rate: float = 1e-3,
    weight_decay: float = 1e-8,
    device="cuda"
) -> tuple[list[float], list[float]]:

    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    avg_training_losses = []
    avg_validation_losses = []

    if epochs is None:
        epochs = 100

    for epoch in range(epochs):

        model.train()
        total_loss = []

        for x, y in tqdm(training_loader):
            x = x.long().to(device)  # important for Embedding
            y = y.long().to(device)  # must be (B,) for CE
            # size (B, 128)
            x = x.long().to(device)

            # size (B)
            y = y.long().to(device)

            optimizer.zero_grad()
            logits = model(x)  # (B, 256)
            logits = _forward(model, x, device)

            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
@@ -49,7 +71,7 @@ def train(
            x = x.long().to(device)
            y = y.long().to(device)

            logits = model(x)
            logits = _forward(model, x, device)
            loss = loss_fn(logits, y)
            losses.append(loss.item())
@@ -1,7 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
@@ -15,8 +13,7 @@ class Trainer(ABC):
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        n_epochs: int | None,
        device: str
    ) -> None:
        pass
    ) -> nn.Module:
        pass
175 src/utils/benchmark.py Normal file
@@ -0,0 +1,175 @@
"""Utility functions for benchmarking."""
import json
import string
from concurrent.futures import ThreadPoolExecutor
from logging import getLogger
from os import getpid, path
from pathlib import Path
from random import choices
from subprocess import DEVNULL, PIPE, CalledProcessError, TimeoutExpired, run
from timeit import timeit
from typing import Callable

from memray import Tracker

from ..utils.benchmark_dataclasses import BenchmarkItem, BenchmarkResult

log = getLogger(__name__)


def get_commit_hash() -> str:
    """
    Get the commit hash of the current git repository.

    If not working in a git repository, return a random string that looks like a commit hash.
    """
    try:
        return run(
            ["git", "rev-parse", "--short", "HEAD"],
            check=True,
            stdout=PIPE,
            stderr=DEVNULL,
            text=True,
        ).stdout.strip()
    except CalledProcessError as e:
        log.error(
            "Could not determine the commit hash. Are you using a git repository?:\n%s",
            e,
        )
        log.error("Using a random string as commit hash.")
        return "".join(choices(string.hexdigits[:-6], k=40))


def init_stat_file(stat_file: Path, header: str) -> int:
    """Initialize a statistics file with a header."""
    # Check if the parent directory exists
    stat_file.parent.mkdir(parents=True, exist_ok=True)

    # Check if the file exists
    if stat_file.exists():
        # Nothing left to do
        return 0

    # Initialize the file by writing the header to it.
    log.debug("Initializing statistics file %s", stat_file)
    stat_file.touch()
    stat_file.write_text(f"{header}\n", encoding="utf-8")
    return 1


def track_time_memory(task: Callable, result: BenchmarkResult, mem_file: Path, mem_json_file: Path):
    """Track the time and memory consumption of a task."""

    def task_with_result():
        result.value = task()

    # Measure memory consumption
    with Tracker(file_name=mem_file, native_traces=True, follow_fork=True, memory_interval_ms=1):
        try:
            # Measure runtime
            result.runtime = timeit(task_with_result, number=1, globals=globals())
        except BaseException as e:
            log.error("Error while timing the program:\n%s", e, exc_info=True)
            return None

    # Convert binary memory file into JSON.
    try:
        run(
            [
                "python",
                "-m",
                "memray",
                "stats",
                "--json",
                "--num-largest",
                "1",
                "--output",
                mem_json_file,
                mem_file,
            ],
            check=True,
            timeout=100,
            stdout=DEVNULL,
        )
        # Parse JSON to get peak_memory
        mem_results = json.loads(mem_json_file.read_text(encoding="utf-8"))
        result.peak_memory = mem_results["metadata"]["peak_memory"]

    except CalledProcessError as e:
        log.error(
            "Something went wrong while processing the memray memory file %s:\n%s",
            mem_file,
            e,
        )
    except TimeoutExpired as e:
        log.error(
            "Timeout expired while processing the memray memory file %s:\n%s",
            mem_file,
            e,
        )

    return result


def execute_benchmark(
    benchmark_item: BenchmarkItem,
    results_dir: str | Path,
    timeout: int = 100,
) -> BenchmarkResult:
    """Execute a benchmark and track its runtime and peak memory consumption."""
    mem_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.mem"))
    mem_json_file = Path(path.join(results_dir, f"memray-{benchmark_item.task.__name__}.json"))

    result = BenchmarkResult(benchmark_item)

    try:
        # Time and track memory usage
        # Kill after timeout in seconds
        with ThreadPoolExecutor() as executor:
            future = executor.submit(
                lambda: track_time_memory(
                    lambda: benchmark_item.task(**benchmark_item.arguments), result, mem_file, mem_json_file
                )
            )
            executed_result = future.result(timeout=timeout)

        if executed_result is not None:
            result = executed_result

        log.info(
            "PID %d: %s finished [%.6f seconds, %d bytes]",
            getpid(),
            benchmark_item.get_method(),
            result.runtime,
            result.peak_memory,
        )
    except TimeoutError:
        log.error("Timeout expired while running the benchmark_suite, cleaning up now.")

        log.info(
            "PID %d: %s failed after timeout (%d seconds)",
            getpid(),
            benchmark_item.get_method(),
            timeout,
        )
    finally:
        # Clean up memory dump file to save disk space.
        mem_file.unlink()

    return result


if __name__ == "__main__":
    import hydra

    # Dummy example, read the contents of the dataset
    def _read_contents(filename):
        with open(filename, encoding="utf-8") as f:
            log.info("Dataset content: %s", f.read())

    def _read_contents_wrapper(cfg):
        return _read_contents(cfg.dataset.path)

    hydra_wrapped = hydra.main(config_path="../../config", config_name="config", version_base="1.2")(
        _read_contents_wrapper
    )()
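A small usage sketch of execute_benchmark with an arbitrary workload (the sleep task and the paths below are made up for illustration):

```python
# Illustrative benchmark of a toy task; timeout and paths are example values.
from time import sleep

from src.utils.benchmark import execute_benchmark
from src.utils.benchmark_dataclasses import BenchmarkItem

def toy_task(n: int = 3):
    sleep(0.1 * n)
    return n

item = BenchmarkItem(task=toy_task, arguments={"n": 3})
result = execute_benchmark(benchmark_item=item, results_dir="results", timeout=60)
print(result.runtime, result.peak_memory, result.value)
```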
79 src/utils/benchmark_dataclasses.py Normal file
@@ -0,0 +1,79 @@
"""
Benchmark data classes.

This module contains the BenchmarkResult class which is used to store and print the results of a
benchmark_suite.
"""

from dataclasses import dataclass
from typing import Any, Callable


@dataclass(init=True)
class BenchmarkItem:
    """A class used to represent a benchmark_suite (iteration)."""

    task: Callable
    arguments: dict

    def __str__(self) -> str:
        """String representation of the BenchmarkItem object."""
        return self.get_in_data_format()

    def get_method(self) -> str:
        """
        Format the method as if it were a function call.
        """
        method_name = self.task.__name__
        arguments = ", ".join(
            f'{key}={str(value)[:15]}'
            for key, value in self.arguments.items()
        )
        return f"{method_name}({arguments})"

    def get_in_data_format(self) -> str:
        """
        Format the benchmark_suite item to be printed to a .dat file.
        """
        # Flatten out arguments
        values = list(self.__dict__.values())
        values[1:2] = values[1].values()

        return " ".join(map(str, values))

    def get_header(self) -> str:
        """
        Returns the header which is just the names of the fields separated by spaces.
        """
        return " ".join(self.__dict__.keys())


@dataclass(init=True)
class BenchmarkResult:
    """A class used to represent the result of a benchmark_suite."""

    benchmark_item: BenchmarkItem
    runtime: float = 0
    peak_memory: int = 0
    value: Any = None

    def __str__(self) -> str:
        """String representation of the BenchmarkResult object."""
        return self.get_in_data_format()

    def get_in_data_format(self) -> str:
        """
        Format the benchmark_suite result to be printed to a .dat file.
        """
        return " ".join(map(str, self.__dict__.values()))

    def get_header(self) -> str:
        """
        Returns the header which is just the names of the fields separated by spaces.
        """
        # Get header of the BenchmarkItem
        keys = list(self.__annotations__.keys())
        keys[0:1] = self.benchmark_item.__annotations__.keys()
        keys[1:2] = self.benchmark_item.arguments.keys()

        return " ".join(keys)
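The get_header/get_in_data_format pair produce matching space-separated rows for the .dat files; a hedged sketch of how they line up (field values are examples):

```python
# Example of the space-separated .dat format the dataclasses emit.
from src.utils.benchmark_dataclasses import BenchmarkItem, BenchmarkResult

def task(n):
    return n * n

item = BenchmarkItem(task=task, arguments={"n": 4})
result = BenchmarkResult(item, runtime=0.25, peak_memory=1024, value=16)

print(result.get_header())          # e.g. "task n runtime peak_memory value"
print(result.get_in_data_format())  # the matching row of values
```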
@@ -1,6 +1,8 @@
from os import path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import TensorDataset
import matplotlib.pyplot as plt


def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:

@@ -10,11 +12,13 @@ def make_context_pairs(data: bytes, context_length: int) -> TensorDataset:
    y = data[context_length:]
    return TensorDataset(x, y)


def print_distribution(from_to: tuple[int, int], probabilities: list[float]):
    plt.hist(range(from_to[0], from_to[1]), weights=probabilities)
    plt.show()

def print_losses(train_losses: list[float], validation_losses: list[float], show=False):

def print_losses(train_losses: list[float], validation_losses: list[float], filename: str | None = None, show=False):
    plt.plot(train_losses, label="Training loss")
    plt.plot(validation_losses, label="Validation loss")
    plt.xlabel("Epoch")

@@ -23,7 +27,26 @@ def print_losses(train_losses: list[float], validation_losses: list[float], show

    if show:
        plt.show()
    plt.savefig("losses.png")

    if filename is None:
        filename = path.join("results", "losses.png")

    print(f"Saving losses to {filename}...")
    plt.savefig(filename)


def determine_device():
    # NVIDIA GPUs (most HPC clusters)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Apple Silicon (macOS)
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return torch.device("mps")
    # Intel GPUs (oneAPI)
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        return torch.device("xpu")
    else:
        return torch.device("cpu")


def load_data(path: str) -> bytes:
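determine_device picks the first available backend in the order cuda → mps → xpu → cpu; a one-line usage sketch:

```python
# Moving a batch to whichever accelerator is available.
import torch

from src.utils import determine_device

device = determine_device()
x = torch.zeros(64, 128, dtype=torch.long).to(device)
print(f"Running on {device}")
```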
trainers/OptunaTrainer.py (deleted)
@@ -1,57 +0,0 @@
from typing import Callable

import optuna
import optuna.trial as tr
import torch
from torch import nn as nn
from torch.utils.data import DataLoader

from .trainer import Trainer
from ..models.cnn import CNNPredictor
from .train import train


def create_model(trial: tr.Trial, vocab_size: int = 256):
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, log=True)
    embedding_dim = trial.suggest_int("embed_dim", 64, 512, log=True)

    return CNNPredictor(
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        embed_dim=embedding_dim,
    )


def objective_function(
    trial: tr.Trial,
    training_loader: DataLoader,
    validation_loader: DataLoader,
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    device: str
):
    model = create_model(trial).to(device)
    _, validation_loss = train(model, training_loader, validation_loader, loss_fn)
    return min(validation_loss)


class OptunaTrainer(Trainer):
    def execute(
        self,
        model: nn.Module | None,
        train_loader: DataLoader,
        validation_loader: DataLoader,
        loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        n_epochs: int,
        device: str
    ) -> None:
        study = optuna.create_study(study_name="CNN network", direction="minimize")
        study.optimize(
            lambda trial: objective_function(trial, train_loader, validation_loader, loss_fn, device),
            n_trials=20
        )

        best_params = study.best_trial.params
        best_model = CNNPredictor(
            **best_params
        )
        torch.save(best_model, f"saved_models/{model.__class__.__name__}.pt")
@@ -1,201 +0,0 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
|
@@ -1,34 +0,0 @@

# Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

This repository contains the code in both **PyTorch** and **TensorFlow** for our paper
>[Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](http://arxiv.org/abs/1901.02860)

>Zihang Dai\*, Zhilin Yang\*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov (*: equal contribution)

>Preprint 2018

## TensorFlow

- The source code is in the `tf/` folder, supporting (1) single-node multi-GPU training, and (2) multi-host TPU training.
- Besides the source code, we also provide pretrained TensorFlow models with the state-of-the-art (SoTA) performance reported in the paper.
- Please refer to `tf/README.md` for details.

## PyTorch

- The source code is in the `pytorch/` folder, supporting single-node multi-GPU training via the module `nn.DataParallel`.
- Please refer to `pytorch/README.md` for details.

## Results

Transformer-XL achieves new state-of-the-art results on multiple language modeling benchmarks. Transformer-XL is also the first to break through the 1.0 barrier on character-level language modeling. Below is a summary.

Method | enwik8 | text8 | One Billion Word | WT-103 | PTB (w/o finetuning)
-- | -- | -- | -- | -- | --
Previous Best | 1.06 | 1.13 | 23.7 | 20.5 | 55.5
Transformer-XL | **0.99** | **1.08** | **21.8** | **18.3** | **54.5**

## Acknowledgement

A large portion of the `getdata.sh` script comes from the [awd-lstm](https://github.com/salesforce/awd-lstm-lm/) repo. Happy Language Modeling :)
@@ -1,90 +0,0 @@

echo "=== Acquiring datasets ==="
echo "---"

mkdir -p data
cd data

if [[ ! -d 'wikitext-2' ]]; then
    echo "- Downloading WikiText-2 (WT2)"
    wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
    unzip -q wikitext-2-v1.zip
    cd wikitext-2
    mv wiki.train.tokens train.txt
    mv wiki.valid.tokens valid.txt
    mv wiki.test.tokens test.txt
    cd ..
fi

echo "- Downloading WikiText-103 (WT103)"
if [[ ! -d 'wikitext-103' ]]; then
    wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
    unzip -q wikitext-103-v1.zip
    cd wikitext-103
    mv wiki.train.tokens train.txt
    mv wiki.valid.tokens valid.txt
    mv wiki.test.tokens test.txt
    cd ..
fi

echo "- Downloading enwik8 (Character)"
if [[ ! -d 'enwik8' ]]; then
    mkdir -p enwik8
    cd enwik8
    wget --continue http://mattmahoney.net/dc/enwik8.zip
    wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
    python3 prep_enwik8.py
    cd ..
fi

echo "- Downloading text8 (Character)"
if [[ ! -d 'text8' ]]; then
    mkdir -p text8
    cd text8
    wget --continue http://mattmahoney.net/dc/text8.zip
    python ../../prep_text8.py
    cd ..
fi

echo "- Downloading Penn Treebank (PTB)"
if [[ ! -d 'penn' ]]; then
    wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
    tar -xzf simple-examples.tgz

    mkdir -p penn
    cd penn
    mv ../simple-examples/data/ptb.train.txt train.txt
    mv ../simple-examples/data/ptb.test.txt test.txt
    mv ../simple-examples/data/ptb.valid.txt valid.txt
    cd ..

    echo "- Downloading Penn Treebank (Character)"
    mkdir -p pennchar
    cd pennchar
    mv ../simple-examples/data/ptb.char.train.txt train.txt
    mv ../simple-examples/data/ptb.char.test.txt test.txt
    mv ../simple-examples/data/ptb.char.valid.txt valid.txt
    cd ..

    rm -rf simple-examples/
fi

echo "- Downloading 1B words"

if [[ ! -d 'one-billion-words' ]]; then
    mkdir -p one-billion-words
    cd one-billion-words

    wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
    tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz

    path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
    cat ${path}/news.en.heldout-00000-of-00050 > valid.txt
    cat ${path}/news.en.heldout-00000-of-00050 > test.txt

    wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt

    cd ..
fi

echo "---"
echo "Happy language modeling :)"
@@ -1,32 +0,0 @@

#!/usr/bin/env python
# coding=utf-8

import os
import sys
import zipfile

from io import open

if os.path.exists('train.txt'):
    print('Tokenized text8 already exists - skipping processing')
    sys.exit()

zipfile.ZipFile('text8.zip').extractall()
data = open('text8', 'r', encoding='utf-8').read()

print('Length of text8: {}'.format(len(data)))

num_test_chars = 5000000

train_data = data[: -2 * num_test_chars]
valid_data = data[-2 * num_test_chars: -num_test_chars]
test_data = data[-num_test_chars:]

for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]:
    print('{} will have {} bytes'.format(fn, len(part)))
    print('- Tokenizing...')
    # Change space ' ' to underscore '_'
    part_str = ' '.join(['_' if c == ' ' else c for c in part.strip()])
    print('- Writing...')
    open(fn, 'w', encoding='utf-8').write(part_str)
    open(fn + '.raw', 'w', encoding='utf-8').write(part)
BIN transformer-xl/pytorch/.DS_Store vendored
Binary file not shown.
@@ -1,62 +0,0 @@

## Introduction

This directory contains our PyTorch implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our PyTorch codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts:
- `*large.sh` are for the SoTA setting with large models, which might not be directly runnable on a local GPU machine.
- `*base.sh` are for the base models, which can be run on a few GPUs.

In our preliminary experiments, the PyTorch implementation produces results similar to those of the TF codebase under the same settings.


## Prerequisite

- PyTorch 0.4: `conda install pytorch torchvision -c pytorch`


## Data Preparation

`bash getdata.sh`

## Training and Evaluation

#### Replicate the "bpc = 1.06" result on `enwik8` with a 12-layer Transformer-XL

- Make sure the machine has **4 GPUs**, each with **at least 11GB of memory**

- Training

  `bash run_enwik8_base.sh train --work_dir PATH_TO_WORK_DIR`

- Evaluation

  `bash run_enwik8_base.sh eval --work_dir PATH_TO_WORK_DIR`


#### Replicate the "PPL = 24.03" result on `wikitext-103` with Transformer-XL

- Make sure the machine has **4 GPUs**, each with **at least 11GB of memory**

- Training

  `bash run_wt103_base.sh train --work_dir PATH_TO_WORK_DIR`

- Evaluation

  `bash run_wt103_base.sh eval --work_dir PATH_TO_WORK_DIR`


#### Other options:

- `--batch_chunk`: this option allows one to trade speed for memory. For `batch_chunk > 1`, the program splits each training batch into `batch_chunk` sub-batches and performs forward and backward on each sub-batch sequentially, with the gradients accumulated and divided by `batch_chunk`. Hence, memory usage decreases proportionally while computation time increases correspondingly (see the sketch after this list).
- `--div_val`: when using adaptive softmax and embedding, the embedding dimension is divided by `div_val` from bin $i$ to bin $i+1$. This saves both GPU memory and the parameter budget (see `AdaptiveEmbedding` in `mem_transformer.py` below).
- `--fp16` and `--dynamic-loss-scale`: run in pseudo-fp16 mode (fp16 storage, fp32 math) with dynamic loss scaling.
  - Note: to use the `--fp16` option, please make sure the `apex` package is installed (https://github.com/NVIDIA/apex/).
- To see performance without the recurrence mechanism, simply use `mem_len=0` in all your scripts.
- To see performance of a standard Transformer without relative positional encodings or recurrence mechanisms, use `attn_type=2` and `mem_len=0`.
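
A minimal sketch of the gradient-accumulation idea behind `--batch_chunk` (illustrative only: the real loop lives in `train.py` and also threads the recurrence memories through; the `train_step` name and the `model`/`optimizer` objects here are stand-ins):

```python
import torch

def train_step(model, optimizer, data, target, batch_chunk):
    # data/target are [seq_len x bsz]; split along the batch dimension.
    data_chunks = torch.chunk(data, batch_chunk, dim=1)
    target_chunks = torch.chunk(target, batch_chunk, dim=1)

    optimizer.zero_grad()
    for data_i, target_i in zip(data_chunks, target_chunks):
        loss_i = model(data_i, target_i)[0].mean()
        # Dividing by batch_chunk makes the accumulated gradient equal to
        # the gradient of the mean loss over the full batch.
        (loss_i / batch_chunk).backward()
    optimizer.step()
```

Only one sub-batch's activations are alive at a time, which is where the memory saving comes from; the extra sequential forward/backward passes are the speed cost.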

#### Other datasets:

- `Text8` character-level language modeling: check out `run_text8_base.sh`
- `lm1b` word-level language modeling: check out `run_lm1b_base.sh`
@@ -1,273 +0,0 @@

import os, sys
import glob

from collections import Counter, OrderedDict
import numpy as np
import torch

from utils.vocabulary import Vocab

class LMOrderedIterator(object):
    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
        """
        data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device

        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = data.size(0) // bsz

        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, self.n_step * bsz)

        # Evenly divide the data across the bsz batches.
        self.data = data.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        if bptt is None: bptt = self.bptt
        seq_len = min(bptt, self.data.size(0) - 1 - i)

        end_idx = i + seq_len
        beg_idx = max(0, i - self.ext_len)

        data = self.data[beg_idx:end_idx]
        target = self.data[i+1:i+1+seq_len]

        return data, target, seq_len

    def get_fixlen_iter(self, start=0):
        for i in range(start, self.data.size(0) - 1, self.bptt):
            yield self.get_batch(i)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        max_len = self.bptt + max_deviation * std
        i = start
        while True:
            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
            data, target, seq_len = self.get_batch(i, bptt)
            i += seq_len
            yield data, target, seq_len
            if i >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()


class LMShuffledIterator(object):
    def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
        """
        data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        # index iterator
        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
            else np.array(range(len(self.data)))

        # sentence iterator
        for idx in epoch_indices:
            yield self.data[idx]

    def stream_iterator(self, sent_stream):
        # streams for each data in the batch
        streams = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        n_retain = 0

        while True:
            # data   : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            valid_batch = True

            for i in range(self.bsz):
                n_filled = 0
                try:
                    while n_filled < self.bptt:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
                            streams[i][:n_new]
                        target[n_filled:n_filled+n_new, i] = \
                            streams[i][1:n_new+1]
                        streams[i] = streams[i][n_new:]
                        n_filled += n_new
                except StopIteration:
                    valid_batch = False
                    break

            if not valid_batch:
                return

            data = data.to(self.device)
            target = target.to(self.device)

            yield data, target, self.bptt

            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                data[:n_retain] = data[-n_retain:]
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            yield batch


class LMMultiFileIterator(LMShuffledIterator):
    def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
                 shuffle=False):

        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        sent_stream = iter(sents)

        return sent_stream

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.paths)

        for path in self.paths:
            # sent_stream is an iterator
            sent_stream = self.get_sent_stream(path)
            for batch in self.stream_iterator(sent_stream):
                yield batch


class Corpus(object):
    def __init__(self, path, dataset, *args, **kwargs):
        self.dataset = dataset
        self.vocab = Vocab(*args, **kwargs)

        if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
            self.vocab.count_file(os.path.join(path, 'train.txt'))
            self.vocab.count_file(os.path.join(path, 'valid.txt'))
            self.vocab.count_file(os.path.join(path, 'test.txt'))
        elif self.dataset == 'wt103':
            self.vocab.count_file(os.path.join(path, 'train.txt'))
        elif self.dataset == 'lm1b':
            train_path_pattern = os.path.join(
                path, '1-billion-word-language-modeling-benchmark-r13output',
                'training-monolingual.tokenized.shuffled', 'news.en-*')
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ['ptb', 'wt2', 'wt103']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True)
        elif self.dataset in ['enwik8', 'text8']:
            self.train = self.vocab.encode_file(
                os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
        elif self.dataset == 'lm1b':
            self.train = train_paths
            self.valid = self.vocab.encode_file(
                os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(
                os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == 'train':
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == 'lm1b':
                kwargs['shuffle'] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        elif split in ['valid', 'test']:
            data = self.valid if split == 'valid' else self.test
            if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == 'lm1b':
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter


def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, 'cache.pt')
    if os.path.exists(fn):
        print('Loading cached dataset...')
        corpus = torch.load(fn)
    else:
        print('Producing dataset {}...'.format(dataset))
        kwargs = {}
        if dataset in ['wt103', 'wt2']:
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = False
        elif dataset == 'ptb':
            kwargs['special'] = ['<eos>']
            kwargs['lower_case'] = True
        elif dataset == 'lm1b':
            kwargs['special'] = []
            kwargs['lower_case'] = False
            kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
        elif dataset in ['enwik8', 'text8']:
            pass

        corpus = Corpus(datadir, dataset, **kwargs)
        torch.save(corpus, fn)

    return corpus

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='unit test')
    parser.add_argument('--datadir', type=str, default='../data/text8',
                        help='location of the data corpus')
    parser.add_argument('--dataset', type=str, default='text8',
                        choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
                        help='dataset name')
    args = parser.parse_args()

    corpus = get_lm_corpus(args.datadir, args.dataset)
    print('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))
@@ -1,122 +0,0 @@

# coding: utf-8
import argparse
import time
import math
import os, sys

import torch

from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import get_logger

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
                    choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                    help='dataset name')
parser.add_argument('--split', type=str, default='all',
                    choices=['all', 'valid', 'test'],
                    help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
                    help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
                    help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
                    help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
                    help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
                    help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'

device = torch.device("cuda" if args.cuda else "cpu")

# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
                     log_=not args.no_log)

# Load dataset
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)

va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                              device=device, ext_len=args.ext_len)

# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
    model = torch.load(f)
model.backward_compatible()
model = model.to(device)

logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
    args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
    model.clamp_len = args.clamp_len
if args.same_length:
    model.same_length = True

###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_len, total_loss = 0, 0.
    start_time = time.time()
    with torch.no_grad():
        mems = tuple()
        for idx, (data, target, seq_len) in enumerate(eval_iter):
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.item()
            total_len += seq_len
        total_time = time.time() - start_time
    logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
        total_time, 1000 * total_time / (idx+1)))
    return total_loss / total_len

# Run on test data.
if args.split == 'all':
    test_loss = evaluate(te_iter)
    valid_loss = evaluate(va_iter)
elif args.split == 'valid':
    valid_loss = evaluate(va_iter)
    test_loss = None
elif args.split == 'test':
    test_loss = evaluate(te_iter)
    valid_loss = None

def format_log(loss, split):
    if args.dataset in ['enwik8', 'text8']:
        log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
            split, loss, loss / math.log(2))
    else:
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
    return log_str
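
# The eval loss is an average negative log-likelihood in nats, so dividing by
# ln(2) converts it to bits-per-character (bpc) for the character-level
# datasets, while exp(loss) gives word-level perplexity (ppl). For example, a
# loss of 1.386 nats is 1.386 / 0.693 = 2.0 bpc, or a perplexity of e^1.386 ≈ 4.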

log_str = ''
if valid_loss is not None:
    log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
    log_str += format_log(test_loss, 'test')

logging('=' * 100)
logging(log_str)
logging('=' * 100)
@@ -1,812 +0,0 @@

import sys
import math
import functools

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('utils')
from proj_adaptive_softmax import ProjectedAdaptiveLogSoftmax
from log_uniform_sampler import LogUniformSampler, sample_logits

class PositionalEmbedding(nn.Module):
    def __init__(self, demb):
        super(PositionalEmbedding, self).__init__()

        self.demb = demb

        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.register_buffer('inv_freq', inv_freq)

    def forward(self, pos_seq, bsz=None):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)

        if bsz is not None:
            return pos_emb[:,None,:].expand(-1, bsz, -1)
        else:
            return pos_emb[:,None,:]


class PositionwiseFF(nn.Module):
    def __init__(self, d_model, d_inner, dropout, pre_lnorm=False):
        super(PositionwiseFF, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner
        self.dropout = dropout

        self.CoreNet = nn.Sequential(
            nn.Linear(d_model, d_inner), nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout),
        )

        self.layer_norm = nn.LayerNorm(d_model)

        self.pre_lnorm = pre_lnorm

    def forward(self, inp):
        if self.pre_lnorm:
            ##### layer normalization + positionwise feed-forward
            core_out = self.CoreNet(self.layer_norm(inp))

            ##### residual connection
            output = core_out + inp
        else:
            ##### positionwise feed-forward
            core_out = self.CoreNet(inp)

            ##### residual connection + layer normalization
            output = self.layer_norm(inp + core_out)

        return output

class MultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 pre_lnorm=False):
        super(MultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.q_net = nn.Linear(d_model, n_head * d_head, bias=False)
        self.kv_net = nn.Linear(d_model, 2 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def forward(self, h, attn_mask=None, mems=None):
        ##### multihead attention
        # [hlen x bsz x n_head x d_head]

        if mems is not None:
            c = torch.cat([mems, h], 0)
        else:
            c = h

        if self.pre_lnorm:
            ##### layer normalization
            c = self.layer_norm(c)

        head_q = self.q_net(h)
        head_k, head_v = torch.chunk(self.kv_net(c), 2, -1)

        head_q = head_q.view(h.size(0), h.size(1), self.n_head, self.d_head)
        head_k = head_k.view(c.size(0), c.size(1), self.n_head, self.d_head)
        head_v = head_v.view(c.size(0), c.size(1), self.n_head, self.d_head)

        # [qlen x klen x bsz x n_head]
        attn_score = torch.einsum('ibnd,jbnd->ijbn', (head_q, head_k))
        attn_score.mul_(self.scale)
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        # [qlen x klen x bsz x n_head] x [klen x bsz x n_head x d_head] -> [qlen x bsz x n_head x d_head]
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, head_v))
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = h + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(h + attn_out)

        return output

class RelMultiHeadAttn(nn.Module):
    def __init__(self, n_head, d_model, d_head, dropout, dropatt=0,
                 tgt_len=None, ext_len=None, mem_len=None, pre_lnorm=False):
        super(RelMultiHeadAttn, self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.d_head = d_head
        self.dropout = dropout

        self.qkv_net = nn.Linear(d_model, 3 * n_head * d_head, bias=False)

        self.drop = nn.Dropout(dropout)
        self.dropatt = nn.Dropout(dropatt)
        self.o_net = nn.Linear(n_head * d_head, d_model, bias=False)

        self.layer_norm = nn.LayerNorm(d_model)

        self.scale = 1 / (d_head ** 0.5)

        self.pre_lnorm = pre_lnorm

    def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])

        if left:
            return mask
        else:
            return mask.flip(0)

    def _shift(self, x, qlen, klen, mask, left=False):
        if qlen > 1:
            zero_pad = torch.zeros((x.size(0), qlen-1, x.size(2), x.size(3)),
                                   device=x.device, dtype=x.dtype)
        else:
            zero_pad = torch.zeros(0, device=x.device, dtype=x.dtype)

        if left:
            mask = mask.flip(1)
            x_padded = torch.cat([zero_pad, x], dim=1).expand(qlen, -1, -1, -1)
        else:
            x_padded = torch.cat([x, zero_pad], dim=1).expand(qlen, -1, -1, -1)

        x = x_padded.masked_select(mask[:,:,None,None]) \
                    .view(qlen, klen, x.size(2), x.size(3))

        return x

    def _rel_shift(self, x, zero_triu=False):
        zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
                               device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=1)

        x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])

        x = x_padded[1:].view_as(x)

        if zero_triu:
            ones = torch.ones((x.size(0), x.size(1)))
            x = x * torch.tril(ones, x.size(1) - x.size(0))[:,:,None,None]

        return x
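
    # Note on `_rel_shift`: this is the relative-shift trick from the
    # Transformer-XL paper. For an input of shape [qlen x klen x ...],
    # padding one zero column on the klen axis, viewing the buffer as
    # [(klen+1) x qlen x ...], dropping the first row, and viewing back
    # shifts row i left by (qlen - 1 - i) positions, which realigns the
    # absolute key indices into relative distances without materializing
    # an explicit qlen x klen gather.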

    def forward(self, w, r, attn_mask=None, mems=None):
        raise NotImplementedError

class RelPartialLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelPartialLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

        self.r_net = nn.Linear(self.d_model, self.n_head * self.d_head, bias=False)

    def forward(self, w, r, r_w_bias, r_r_bias, attn_mask=None, mems=None):
        qlen, rlen, bsz = w.size(0), r.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)
            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            r_head_k = self.r_net(r)

            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)  # qlen x bsz x n_head x d_head
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)  # klen x bsz x n_head x d_head

        r_head_k = r_head_k.view(rlen, self.n_head, self.d_head)       # rlen x n_head x d_head

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias                                # qlen x bsz x n_head x d_head
        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))    # qlen x klen x bsz x n_head

        rr_head_q = w_head_q + r_r_bias
        BD = torch.einsum('ibnd,jnd->ijbn', (rr_head_q, r_head_k))     # qlen x klen x bsz x n_head
        BD = self._rel_shift(BD)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[None,:,:,None], -float('inf')).type_as(attn_score)
            elif attn_mask.dim() == 3:
                attn_score = attn_score.float().masked_fill(
                    attn_mask[:,:,:,None], -float('inf')).type_as(attn_score)

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output

class RelLearnableMultiHeadAttn(RelMultiHeadAttn):
    def __init__(self, *args, **kwargs):
        super(RelLearnableMultiHeadAttn, self).__init__(*args, **kwargs)

    def forward(self, w, r_emb, r_w_bias, r_bias, attn_mask=None, mems=None):
        # r_emb: [klen, n_head, d_head], used for term B
        # r_w_bias: [n_head, d_head], used for term C
        # r_bias: [klen, n_head], used for term D

        qlen, bsz = w.size(0), w.size(1)

        if mems is not None:
            cat = torch.cat([mems, w], 0)
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(cat))
            else:
                w_heads = self.qkv_net(cat)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

            w_head_q = w_head_q[-qlen:]
        else:
            if self.pre_lnorm:
                w_heads = self.qkv_net(self.layer_norm(w))
            else:
                w_heads = self.qkv_net(w)
            w_head_q, w_head_k, w_head_v = torch.chunk(w_heads, 3, dim=-1)

        klen = w_head_k.size(0)

        w_head_q = w_head_q.view(qlen, bsz, self.n_head, self.d_head)
        w_head_k = w_head_k.view(klen, bsz, self.n_head, self.d_head)
        w_head_v = w_head_v.view(klen, bsz, self.n_head, self.d_head)

        if klen > r_emb.size(0):
            r_emb_pad = r_emb[0:1].expand(klen-r_emb.size(0), -1, -1)
            r_emb = torch.cat([r_emb_pad, r_emb], 0)
            r_bias_pad = r_bias[0:1].expand(klen-r_bias.size(0), -1)
            r_bias = torch.cat([r_bias_pad, r_bias], 0)
        else:
            r_emb = r_emb[-klen:]
            r_bias = r_bias[-klen:]

        #### compute attention score
        rw_head_q = w_head_q + r_w_bias[None]                          # qlen x bsz x n_head x d_head

        AC = torch.einsum('ibnd,jbnd->ijbn', (rw_head_q, w_head_k))    # qlen x klen x bsz x n_head
        B_ = torch.einsum('ibnd,jnd->ijbn', (w_head_q, r_emb))         # qlen x klen x bsz x n_head
        D_ = r_bias[None, :, None]                                     # 1    x klen x 1   x n_head
        BD = self._rel_shift(B_ + D_)

        # [qlen x klen x bsz x n_head]
        attn_score = AC + BD
        attn_score.mul_(self.scale)

        #### compute attention probability
        if attn_mask is not None and attn_mask.any().item():
            if attn_mask.dim() == 2:
                attn_score.masked_fill_(attn_mask[None,:,:,None], -float('inf'))
            elif attn_mask.dim() == 3:
                attn_score.masked_fill_(attn_mask[:,:,:,None], -float('inf'))

        # [qlen x klen x bsz x n_head]
        attn_prob = F.softmax(attn_score, dim=1)
        attn_prob = self.dropatt(attn_prob)

        #### compute attention vector
        attn_vec = torch.einsum('ijbn,jbnd->ibnd', (attn_prob, w_head_v))

        # [qlen x bsz x n_head x d_head]
        attn_vec = attn_vec.contiguous().view(
            attn_vec.size(0), attn_vec.size(1), self.n_head * self.d_head)

        ##### linear projection
        attn_out = self.o_net(attn_vec)
        attn_out = self.drop(attn_out)

        if self.pre_lnorm:
            ##### residual connection
            output = w + attn_out
        else:
            ##### residual connection + layer normalization
            output = self.layer_norm(w + attn_out)

        return output

class DecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super(DecoderLayer, self).__init__()

        self.dec_attn = MultiHeadAttn(n_head, d_model, d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output

class RelLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelLearnableMultiHeadAttn(n_head, d_model, d_head, dropout,
                                                  **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r_emb, r_w_bias, r_bias, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, r_emb, r_w_bias, r_bias,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output

class RelPartialLearnableDecoderLayer(nn.Module):
    def __init__(self, n_head, d_model, d_head, d_inner, dropout,
                 **kwargs):
        super(RelPartialLearnableDecoderLayer, self).__init__()

        self.dec_attn = RelPartialLearnableMultiHeadAttn(n_head, d_model,
                                                         d_head, dropout, **kwargs)
        self.pos_ff = PositionwiseFF(d_model, d_inner, dropout,
                                     pre_lnorm=kwargs.get('pre_lnorm'))

    def forward(self, dec_inp, r, r_w_bias, r_r_bias, dec_attn_mask=None, mems=None):

        output = self.dec_attn(dec_inp, r, r_w_bias, r_r_bias,
                               attn_mask=dec_attn_mask,
                               mems=mems)
        output = self.pos_ff(output)

        return output


class AdaptiveEmbedding(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 sample_softmax=False):
        super(AdaptiveEmbedding, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed

        self.cutoffs = cutoffs + [n_token]
        self.div_val = div_val
        self.d_proj = d_proj

        self.emb_scale = d_proj ** 0.5

        self.cutoff_ends = [0] + self.cutoffs

        self.emb_layers = nn.ModuleList()
        self.emb_projs = nn.ParameterList()
        if div_val == 1:
            self.emb_layers.append(
                nn.Embedding(n_token, d_embed, sparse=sample_softmax>0)
            )
            if d_proj != d_embed:
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_embed)))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)
                self.emb_layers.append(nn.Embedding(r_idx-l_idx, d_emb_i))
                self.emb_projs.append(nn.Parameter(torch.Tensor(d_proj, d_emb_i)))

    def forward(self, inp):
        if self.div_val == 1:
            embed = self.emb_layers[0](inp)
            if self.d_proj != self.d_embed:
                embed = F.linear(embed, self.emb_projs[0])
        else:
            param = next(self.parameters())
            inp_flat = inp.view(-1)
            emb_flat = torch.zeros([inp_flat.size(0), self.d_proj],
                                   dtype=param.dtype, device=param.device)
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                inp_i = inp_flat.index_select(0, indices_i) - l_idx
                emb_i = self.emb_layers[i](inp_i)
                emb_i = F.linear(emb_i, self.emb_projs[i])

                emb_flat.index_copy_(0, indices_i, emb_i)

            embed = emb_flat.view(*inp.size(), self.d_proj)

        embed.mul_(self.emb_scale)

        return embed

class MemTransformerLM(nn.Module):
    def __init__(self, n_token, n_layer, n_head, d_model, d_head, d_inner,
                 dropout, dropatt, tie_weight=True, d_embed=None,
                 div_val=1, tie_projs=[False], pre_lnorm=False,
                 tgt_len=None, ext_len=None, mem_len=None,
                 cutoffs=[], adapt_inp=False,
                 same_length=False, attn_type=0, clamp_len=-1,
                 sample_softmax=-1):
        super(MemTransformerLM, self).__init__()
        self.n_token = n_token

        d_embed = d_model if d_embed is None else d_embed
        self.d_embed = d_embed
        self.d_model = d_model
        self.n_head = n_head
        self.d_head = d_head

        self.word_emb = AdaptiveEmbedding(n_token, d_embed, d_model, cutoffs,
                                          div_val=div_val)

        self.drop = nn.Dropout(dropout)

        self.n_layer = n_layer

        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len
        self.max_klen = tgt_len + ext_len + mem_len

        self.attn_type = attn_type

        self.layers = nn.ModuleList()
        if attn_type == 0: # the default attention
            for i in range(n_layer):
                self.layers.append(
                    RelPartialLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type == 1: # learnable embeddings
            for i in range(n_layer):
                self.layers.append(
                    RelLearnableDecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )
        elif attn_type in [2, 3]: # absolute embeddings
            for i in range(n_layer):
                self.layers.append(
                    DecoderLayer(
                        n_head, d_model, d_head, d_inner, dropout,
                        dropatt=dropatt, pre_lnorm=pre_lnorm)
                )

        self.sample_softmax = sample_softmax
        # use sampled softmax
        if sample_softmax > 0:
            self.out_layer = nn.Linear(d_model, n_token)
            if tie_weight:
                self.out_layer.weight = self.word_emb.weight
            self.tie_weight = tie_weight
            self.sampler = LogUniformSampler(n_token, sample_softmax)

        # use adaptive softmax (including standard softmax)
        else:
            self.crit = ProjectedAdaptiveLogSoftmax(n_token, d_embed, d_model,
                                                    cutoffs, div_val=div_val)

            if tie_weight:
                for i in range(len(self.crit.out_layers)):
                    self.crit.out_layers[i].weight = self.word_emb.emb_layers[i].weight

            if tie_projs:
                for i, tie_proj in enumerate(tie_projs):
                    if tie_proj and div_val == 1 and d_model != d_embed:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[0]
                    elif tie_proj and div_val != 1:
                        self.crit.out_projs[i] = self.word_emb.emb_projs[i]

        self.same_length = same_length
        self.clamp_len = clamp_len

        self._create_params()

    def backward_compatible(self):
        self.sample_softmax = -1

    def _create_params(self):
        if self.attn_type == 0: # default attention
            self.pos_emb = PositionalEmbedding(self.d_model)
            self.r_w_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
            self.r_r_bias = nn.Parameter(torch.Tensor(self.n_head, self.d_head))
        elif self.attn_type == 1: # learnable
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))
            self.r_w_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.n_head, self.d_head))
            self.r_bias = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head))
        elif self.attn_type == 2: # absolute standard
            self.pos_emb = PositionalEmbedding(self.d_model)
        elif self.attn_type == 3: # absolute deeper SA
            self.r_emb = nn.Parameter(torch.Tensor(
                self.n_layer, self.max_klen, self.n_head, self.d_head))

    def reset_length(self, tgt_len, ext_len, mem_len):
        self.tgt_len = tgt_len
        self.mem_len = mem_len
        self.ext_len = ext_len

    def init_mems(self):
        if self.mem_len > 0:
            mems = []
            param = next(self.parameters())
            for i in range(self.n_layer+1):
                empty = torch.empty(0, dtype=param.dtype, device=param.device)
                mems.append(empty)

            return mems
        else:
            return None

    def _update_mems(self, hids, mems, qlen, mlen):
        # does not deal with None
        if mems is None: return None

        # mems is not None
        assert len(hids) == len(mems), 'len(hids) != len(mems)'

        # There are `mlen + qlen` steps that can be cached into mems
        # For the next step, the last `ext_len` of the `qlen` tokens
        # will be used as the extended context. Hence, we only cache
        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
        # to `mlen + qlen - self.ext_len`.
        with torch.no_grad():
            new_mems = []
            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
            beg_idx = max(0, end_idx - self.mem_len)
            for i in range(len(hids)):

                cat = torch.cat([mems[i], hids[i]], dim=0)
                new_mems.append(cat[beg_idx:end_idx].detach())

        return new_mems

    def _forward(self, dec_inp, mems=None):
        qlen, bsz = dec_inp.size()

        word_emb = self.word_emb(dec_inp)

        mlen = mems[0].size(0) if mems is not None else 0
        klen = mlen + qlen
        if self.same_length:
            all_ones = word_emb.new_ones(qlen, klen)
            mask_len = klen - self.mem_len
            if mask_len > 0:
                mask_shift_len = qlen - mask_len
            else:
                mask_shift_len = qlen
            dec_attn_mask = (torch.triu(all_ones, 1+mlen)
                             + torch.tril(all_ones, -mask_shift_len)).byte()[:, :, None] # -1
        else:
            dec_attn_mask = torch.triu(
                word_emb.new_ones(qlen, klen), diagonal=1+mlen).byte()[:,:,None]

        hids = []
        if self.attn_type == 0: # default
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb)
            pos_emb = self.drop(pos_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, pos_emb, self.r_w_bias,
                                 self.r_r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 1: # learnable
            core_out = self.drop(word_emb)
            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                if self.clamp_len > 0:
                    r_emb = self.r_emb[i][-self.clamp_len :]
                    r_bias = self.r_bias[i][-self.clamp_len :]
                else:
                    r_emb, r_bias = self.r_emb[i], self.r_bias[i]

                mems_i = None if mems is None else mems[i]
                core_out = layer(core_out, r_emb, self.r_w_bias[i],
                                 r_bias, dec_attn_mask=dec_attn_mask, mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 2: # absolute
            pos_seq = torch.arange(klen - 1, -1, -1.0, device=word_emb.device,
                                   dtype=word_emb.dtype)
            if self.clamp_len > 0:
                pos_seq.clamp_(max=self.clamp_len)
            pos_emb = self.pos_emb(pos_seq)

            core_out = self.drop(word_emb + pos_emb[-qlen:])

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and i == 0:
                    mems_i += pos_emb[:mlen]
                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)
        elif self.attn_type == 3:
            core_out = self.drop(word_emb)

            hids.append(core_out)
            for i, layer in enumerate(self.layers):
                mems_i = None if mems is None else mems[i]
                if mems_i is not None and mlen > 0:
                    cur_emb = self.r_emb[i][:-qlen]
                    cur_size = cur_emb.size(0)
                    if cur_size < mlen:
                        cur_emb_pad = cur_emb[0:1].expand(mlen-cur_size, -1, -1)
                        cur_emb = torch.cat([cur_emb_pad, cur_emb], 0)
                    else:
                        cur_emb = cur_emb[-mlen:]
                    mems_i += cur_emb.view(mlen, 1, -1)
                core_out += self.r_emb[i][-qlen:].view(qlen, 1, -1)

                core_out = layer(core_out, dec_attn_mask=dec_attn_mask,
                                 mems=mems_i)
                hids.append(core_out)

        core_out = self.drop(core_out)

        new_mems = self._update_mems(hids, mems, qlen, mlen)

        return core_out, new_mems

    def forward(self, data, target, *mems):
        # nn.DataParallel does not allow size(0) tensors to be broadcasted.
        # So, have to initialize size(0) mems inside the model forward.
        # Moreover, have to return new_mems to allow nn.DataParallel to piece
        # them together.
        if not mems: mems = self.init_mems()

        tgt_len = target.size(0)
        hidden, new_mems = self._forward(data, mems=mems)

        pred_hid = hidden[-tgt_len:]
        if self.sample_softmax > 0 and self.training:
            assert self.tie_weight
            logit = sample_logits(self.word_emb,
                                  self.out_layer.bias, target, pred_hid, self.sampler)
            loss = -F.log_softmax(logit, -1)[:, :, 0]
        else:
            loss = self.crit(pred_hid.view(-1, pred_hid.size(-1)), target.view(-1))
            loss = loss.view(tgt_len, -1)

        if new_mems is None:
            return [loss]
        else:
            return [loss] + new_mems

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='unit test')

    parser.add_argument('--n_layer', type=int, default=4, help='')
    parser.add_argument('--n_rel_layer', type=int, default=4, help='')
    parser.add_argument('--n_head', type=int, default=2, help='')
    parser.add_argument('--d_head', type=int, default=2, help='')
    parser.add_argument('--d_model', type=int, default=200, help='')
    parser.add_argument('--d_embed', type=int, default=200, help='')
    parser.add_argument('--d_inner', type=int, default=200, help='')
    parser.add_argument('--dropout', type=float, default=0.0, help='')
    parser.add_argument('--cuda', action='store_true', help='')
    parser.add_argument('--seed', type=int, default=1111, help='')
    parser.add_argument('--multi_gpu', action='store_true', help='')

    args = parser.parse_args()

    device = torch.device("cuda" if args.cuda else "cpu")

    B = 4
    tgt_len, mem_len, ext_len = 36, 36, 0
    data_len = tgt_len * 20
    args.n_token = 10000

    import data_utils

    data = torch.LongTensor(data_len*B).random_(0, args.n_token).to(device)
    diter = data_utils.LMOrderedIterator(data, B, tgt_len, device=device, ext_len=ext_len)

    cutoffs = [args.n_token // 2]
    tie_projs = [False] + [True] * len(cutoffs)

    for div_val in [1, 2]:
        for d_embed in [200, 100]:
            model = MemTransformerLM(args.n_token, args.n_layer, args.n_head,
                                     args.d_model, args.d_head, args.d_inner, args.dropout,
                                     dropatt=args.dropout, tie_weight=True,
                                     d_embed=d_embed, div_val=div_val,
                                     tie_projs=tie_projs, pre_lnorm=True,
                                     tgt_len=tgt_len, ext_len=ext_len, mem_len=mem_len,
                                     cutoffs=cutoffs, attn_type=0).to(device)

            print(sum(p.numel() for p in model.parameters()))

            mems = tuple()
            for idx, (inp, tgt, seqlen) in enumerate(diter):
                print('batch {}'.format(idx))
                out = model(inp, tgt, *mems)
                mems = out[1:]
@@ -1,41 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --n_layer 12 \
        --d_model 512 \
        --n_head 8 \
        --d_head 64 \
        --d_inner 2048 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 400000 \
        --tgt_len 512 \
        --mem_len 512 \
        --eval_tgt_len 128 \
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --tgt_len 80 \
        --mem_len 2100 \
        --clamp_len 820 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,41 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --n_layer 24 \
        --d_model 1024 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 3072 \
        --dropout 0.15 \
        --dropatt 0.15 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 4000 \
        --max_step 400000 \
        --tgt_len 768 \
        --mem_len 768 \
        --eval_tgt_len 128 \
        --batch_size 64 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/enwik8/ \
        --dataset enwik8 \
        --tgt_len 128 \
        --mem_len 3800 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,43 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --adaptive \
        --n_layer 18 \
        --d_model 1024 \
        --div_val 4 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 4096 \
        --dropout 0.0 \
        --dropatt 0.0 \
        --optim adam \
        --warmup_step 20000 \
        --max_step 500000 \
        --lr 0.00025 \
        --tgt_len 32 \
        --mem_len 32 \
        --eval_tgt_len 32 \
        --batch_size 224 \
        --multi_gpu \
        --gpu0_bsz 32 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --batch_size 64 \
        --tgt_len 32 \
        --mem_len 128 \
        --split test \
        --same_length \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,43 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --adaptive \
        --div_val 4 \
        --n_layer 24 \
        --d_model 1280 \
        --n_head 16 \
        --d_head 80 \
        --d_inner 8192 \
        --dropout 0.05 \
        --dropatt 0.05 \
        --optim adam \
        --warmup_step 30000 \
        --max_step 1200000 \
        --lr 0.00025 \
        --tgt_len 32 \
        --mem_len 32 \
        --eval_tgt_len 32 \
        --batch_size 512 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/one-billion-words/ \
        --dataset lm1b \
        --batch_size 8 \
        --tgt_len 32 \
        --mem_len 128 \
        --split test \
        --same_length \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,41 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --n_layer 12 \
        --d_model 512 \
        --n_head 8 \
        --d_head 64 \
        --d_inner 2048 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 400000 \
        --tgt_len 512 \
        --mem_len 512 \
        --eval_tgt_len 128 \
        --batch_size 22 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --tgt_len 80 \
        --mem_len 2100 \
        --clamp_len 820 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,38 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --n_layer 24 \
        --d_model 1024 \
        --n_head 8 \
        --d_head 128 \
        --d_inner 3072 \
        --dropout 0.15 \
        --dropatt 0.15 \
        --optim adam \
        --lr 0.00025 \
        --tgt_len 768 \
        --mem_len 768 \
        --eval_tgt_len 128 \
        --batch_size 64 \
        --max_step 400000 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/text8/ \
        --dataset text8 \
        --tgt_len 128 \
        --mem_len 3800 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,42 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --adaptive \
        --n_layer 16 \
        --d_model 410 \
        --n_head 10 \
        --d_head 41 \
        --d_inner 2100 \
        --dropout 0.1 \
        --dropatt 0.0 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 0 \
        --max_step 200000 \
        --tgt_len 150 \
        --mem_len 150 \
        --eval_tgt_len 150 \
        --batch_size 60 \
        --multi_gpu \
        --gpu0_bsz 4 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --tgt_len 64 \
        --mem_len 640 \
        --clamp_len 400 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,43 +0,0 @@
#!/bin/bash

if [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --adaptive \
        --div_val 4 \
        --n_layer 18 \
        --d_model 1024 \
        --n_head 16 \
        --d_head 64 \
        --d_inner 4096 \
        --dropout 0.2 \
        --dropatt 0.2 \
        --optim adam \
        --lr 0.00025 \
        --warmup_step 16000 \
        --max_step 4000000 \
        --tgt_len 384 \
        --mem_len 384 \
        --eval_tgt_len 128 \
        --batch_size 128 \
        --multi_gpu \
        --gpu0_bsz 0 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python eval.py \
        --cuda \
        --data ../data/wikitext-103/ \
        --dataset wt103 \
        --tgt_len 128 \
        --mem_len 1600 \
        --clamp_len 1000 \
        --same_length \
        --split test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@@ -1,562 +0,0 @@
# coding: utf-8
import argparse
import time
import math
import os, sys
import itertools

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import create_exp_dir
from utils.data_parallel import BalancedDataParallel

parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
                    help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
                    choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                    help='dataset name')
parser.add_argument('--n_layer', type=int, default=12,
                    help='number of total layers')
parser.add_argument('--n_head', type=int, default=10,
                    help='number of heads')
parser.add_argument('--d_head', type=int, default=50,
                    help='head dimension')
parser.add_argument('--d_embed', type=int, default=-1,
                    help='embedding dimension')
parser.add_argument('--d_model', type=int, default=500,
                    help='model dimension')
parser.add_argument('--d_inner', type=int, default=1000,
                    help='inner dimension in FF')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='global dropout rate')
parser.add_argument('--dropatt', type=float, default=0.0,
                    help='attention probability dropout rate')
parser.add_argument('--init', default='normal', type=str,
                    help='parameter initializer to use.')
parser.add_argument('--emb_init', default='normal', type=str,
                    help='parameter initializer to use.')
parser.add_argument('--init_range', type=float, default=0.1,
                    help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--emb_init_range', type=float, default=0.01,
                    help='parameters initialized by U(-init_range, init_range)')
parser.add_argument('--init_std', type=float, default=0.02,
                    help='parameters initialized by N(0, init_std)')
parser.add_argument('--proj_init_std', type=float, default=0.01,
                    help='parameters initialized by N(0, init_std)')
parser.add_argument('--optim', default='adam', type=str,
                    choices=['adam', 'sgd', 'adagrad'],
                    help='optimizer to use.')
parser.add_argument('--lr', type=float, default=0.00025,
                    help='initial learning rate (0.00025|5 for adam|sgd)')
parser.add_argument('--mom', type=float, default=0.0,
                    help='momentum for sgd')
parser.add_argument('--scheduler', default='cosine', type=str,
                    choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                    help='lr scheduler to use.')
parser.add_argument('--warmup_step', type=int, default=0,
                    help='number of linear warmup steps')
parser.add_argument('--decay_rate', type=float, default=0.5,
                    help='decay factor when ReduceLROnPlateau is used')
parser.add_argument('--lr_min', type=float, default=0.0,
                    help='minimum learning rate during annealing')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--clip_nonemb', action='store_true',
                    help='only clip the gradient of non-embedding params')
parser.add_argument('--max_step', type=int, default=100000,
                    help='maximum number of training steps')
parser.add_argument('--batch_size', type=int, default=60,
                    help='batch size')
parser.add_argument('--batch_chunk', type=int, default=1,
                    help='split batch into chunks to save memory')
parser.add_argument('--tgt_len', type=int, default=70,
                    help='number of tokens to predict')
parser.add_argument('--eval_tgt_len', type=int, default=50,
                    help='number of tokens to predict for evaluation')
parser.add_argument('--ext_len', type=int, default=0,
                    help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
                    help='length of the retained previous heads')
parser.add_argument('--not_tied', action='store_true',
                    help='do not tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--adaptive', action='store_true',
                    help='use adaptive softmax')
parser.add_argument('--div_val', type=int, default=1,
                    help='dividend value for adaptive input and softmax')
parser.add_argument('--pre_lnorm', action='store_true',
                    help='apply LayerNorm to the input instead of the output')
parser.add_argument('--varlen', action='store_true',
                    help='use variable length')
parser.add_argument('--multi_gpu', action='store_true',
                    help='use multiple GPU')
parser.add_argument('--log-interval', type=int, default=200,
                    help='report interval')
parser.add_argument('--eval-interval', type=int, default=4000,
                    help='evaluation interval')
parser.add_argument('--work_dir', default='LM-TFM', type=str,
                    help='experiment directory.')
parser.add_argument('--restart', action='store_true',
                    help='restart training from the saved checkpoint')
parser.add_argument('--restart_dir', type=str, default='',
                    help='restart dir')
parser.add_argument('--debug', action='store_true',
                    help='run in debug mode (do not create exp dir)')
parser.add_argument('--same_length', action='store_true',
                    help='use the same attn length for all tokens')
parser.add_argument('--attn_type', type=int, default=0,
                    help='attention type. 0 for ours, 1 for Shaw et al,'
                    '2 for Vaswani et al, 3 for Al Rfou et al.')
parser.add_argument('--clamp_len', type=int, default=-1,
                    help='use the same pos embeddings after clamp_len')
parser.add_argument('--eta_min', type=float, default=0.0,
                    help='min learning rate for cosine scheduler')
parser.add_argument('--gpu0_bsz', type=int, default=-1,
                    help='batch size on gpu 0')
parser.add_argument('--max_eval_steps', type=int, default=-1,
                    help='max eval steps')
parser.add_argument('--sample_softmax', type=int, default=-1,
                    help='number of samples in sampled softmax')
parser.add_argument('--patience', type=int, default=0,
                    help='patience')
parser.add_argument('--finetune_v2', action='store_true',
                    help='finetune v2')
parser.add_argument('--finetune_v3', action='store_true',
                    help='finetune v3')
parser.add_argument('--fp16', action='store_true',
                    help='Run in pseudo-fp16 mode (fp16 storage fp32 math).')
parser.add_argument('--static-loss-scale', type=float, default=1,
                    help='Static loss scale, positive power of 2 values can '
                    'improve fp16 convergence.')
parser.add_argument('--dynamic-loss-scale', action='store_true',
                    help='Use dynamic loss scaling. If supplied, this argument'
                    ' supersedes --static-loss-scale.')
args = parser.parse_args()
args.tied = not args.not_tied

if args.d_embed < 0:
    args.d_embed = args.d_model

assert args.ext_len >= 0, 'extended context length must be non-negative'
assert args.batch_size % args.batch_chunk == 0

args.work_dir = '{}-{}'.format(args.work_dir, args.dataset)
args.work_dir = os.path.join(args.work_dir, time.strftime('%Y%m%d-%H%M%S'))
logging = create_exp_dir(args.work_dir,
    scripts_to_save=['train.py', 'mem_transformer.py'], debug=args.debug)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print('WARNING: You have a CUDA device, so you should probably run with --cuda')
    else:
        torch.cuda.manual_seed_all(args.seed)

# Validate `--fp16` option
if args.fp16:
    if not args.cuda:
        print('WARNING: --fp16 requires --cuda, ignoring --fp16 option')
        args.fp16 = False
    else:
        try:
            from apex.fp16_utils import FP16_Optimizer
        except:
            print('WARNING: apex not installed, ignoring --fp16 option')
            args.fp16 = False

device = torch.device('cuda' if args.cuda else 'cpu')

###############################################################################
# Load data
###############################################################################
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
args.n_token = ntokens

eval_batch_size = 10
tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
    device=device, ext_len=args.ext_len)
va_iter = corpus.get_iterator('valid', eval_batch_size, args.eval_tgt_len,
    device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', eval_batch_size, args.eval_tgt_len,
    device=device, ext_len=args.ext_len)

# adaptive softmax / embedding
cutoffs, tie_projs = [], [False]
if args.adaptive:
    assert args.dataset in ['wt103', 'lm1b']
    if args.dataset == 'wt103':
        cutoffs = [20000, 40000, 200000]
        tie_projs += [True] * len(cutoffs)
    elif args.dataset == 'lm1b':
        cutoffs = [60000, 100000, 640000]
        tie_projs += [False] * len(cutoffs)

###############################################################################
# Build the model
###############################################################################
def init_weight(weight):
    if args.init == 'uniform':
        nn.init.uniform_(weight, -args.init_range, args.init_range)
    elif args.init == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)

def init_bias(bias):
    nn.init.constant_(bias, 0.0)

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            init_weight(m.weight)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('AdaptiveEmbedding') != -1:
        if hasattr(m, 'emb_projs'):
            for i in range(len(m.emb_projs)):
                if m.emb_projs[i] is not None:
                    nn.init.normal_(m.emb_projs[i], 0.0, args.proj_init_std)
    elif classname.find('Embedding') != -1:
        if hasattr(m, 'weight'):
            init_weight(m.weight)
    elif classname.find('ProjectedAdaptiveLogSoftmax') != -1:
        if hasattr(m, 'cluster_weight') and m.cluster_weight is not None:
            init_weight(m.cluster_weight)
        if hasattr(m, 'cluster_bias') and m.cluster_bias is not None:
            init_bias(m.cluster_bias)
        if hasattr(m, 'out_projs'):
            for i in range(len(m.out_projs)):
                if m.out_projs[i] is not None:
                    nn.init.normal_(m.out_projs[i], 0.0, args.proj_init_std)
    elif classname.find('LayerNorm') != -1:
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if hasattr(m, 'bias') and m.bias is not None:
            init_bias(m.bias)
    elif classname.find('TransformerLM') != -1:
        if hasattr(m, 'r_emb'):
            init_weight(m.r_emb)
        if hasattr(m, 'r_w_bias'):
            init_weight(m.r_w_bias)
        if hasattr(m, 'r_r_bias'):
            init_weight(m.r_r_bias)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)

def update_dropout(m):
    classname = m.__class__.__name__
    if classname.find('Dropout') != -1:
        if hasattr(m, 'p'):
            m.p = args.dropout

def update_dropatt(m):
    if hasattr(m, 'dropatt'):
        m.dropatt.p = args.dropatt

if args.restart:
    with open(os.path.join(args.restart_dir, 'model.pt'), 'rb') as f:
        model = torch.load(f)
    if not args.fp16:
        model = model.float()
    model.apply(update_dropout)
    model.apply(update_dropatt)
else:
    model = MemTransformerLM(ntokens, args.n_layer, args.n_head, args.d_model,
        args.d_head, args.d_inner, args.dropout, args.dropatt,
        tie_weight=args.tied, d_embed=args.d_embed, div_val=args.div_val,
        tie_projs=tie_projs, pre_lnorm=args.pre_lnorm, tgt_len=args.tgt_len,
        ext_len=args.ext_len, mem_len=args.mem_len, cutoffs=cutoffs,
        same_length=args.same_length, attn_type=args.attn_type,
        clamp_len=args.clamp_len, sample_softmax=args.sample_softmax)
    model.apply(weights_init)
    model.word_emb.apply(weights_init) # ensure embedding init is not overridden by out_layer in case of weight sharing
args.n_all_param = sum([p.nelement() for p in model.parameters()])
args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

if args.fp16:
    model = model.half()

if args.multi_gpu:
    model = model.to(device)
    if args.gpu0_bsz >= 0:
        para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                          model, dim=1).to(device)
    else:
        para_model = nn.DataParallel(model, dim=1).to(device)
else:
    para_model = model.to(device)

#### optimizer
if args.optim.lower() == 'sgd':
    if args.sample_softmax > 0:
        dense_params, sparse_params = [], []
        for param in model.parameters():
            if param.size() == model.word_emb.weight.size():
                sparse_params.append(param)
            else:
                dense_params.append(param)
        optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
        optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=args.mom)
elif args.optim.lower() == 'adam':
    if args.sample_softmax > 0:
        dense_params, sparse_params = [], []
        for param in model.parameters():
            if param.size() == model.word_emb.weight.size():
                sparse_params.append(param)
            else:
                dense_params.append(param)
        optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
        optimizer = optim.Adam(dense_params, lr=args.lr)
    else:
        optimizer = optim.Adam(model.parameters(), lr=args.lr)
elif args.optim.lower() == 'adagrad':
    optimizer = optim.Adagrad(model.parameters(), lr=args.lr)

#### scheduler
if args.scheduler == 'cosine':
    # here we do not set eta_min to lr_min to be backward compatible,
    # because in previous versions eta_min defaulted to 0
    # rather than the default value of lr_min, 1e-6
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
        args.max_step, eta_min=args.eta_min) # should use eta_min arg
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(optimizer_sparse,
            args.max_step, eta_min=args.eta_min) # should use eta_min arg
elif args.scheduler == 'inv_sqrt':
    # originally used for Transformer (in Attention is all you need)
    def lr_lambda(step):
        # return a multiplier instead of a learning rate
        if step == 0 and args.warmup_step == 0:
            return 1.
        else:
            return 1. / (step ** 0.5) if step > args.warmup_step \
                else step / (args.warmup_step ** 1.5)
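    # A worked check (editor's note, not part of the original file): the two
    # branches agree at step == warmup_step, since w / w**1.5 == w**-0.5, so
    # the multiplier climbs linearly to 1/sqrt(warmup_step) and then decays
    # as 1/sqrt(step).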
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
elif args.scheduler == 'dev_perf':
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
        factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
    if args.sample_softmax > 0:
        scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(optimizer_sparse,
            factor=args.decay_rate, patience=args.patience, min_lr=args.lr_min)
elif args.scheduler == 'constant':
    pass

if args.cuda and args.fp16:
    # If args.dynamic_loss_scale is False, static_loss_scale will be used.
    # If args.dynamic_loss_scale is True, it will take precedence over static_loss_scale.
    optimizer = FP16_Optimizer(optimizer,
                               static_loss_scale = args.static_loss_scale,
                               dynamic_loss_scale = args.dynamic_loss_scale,
                               dynamic_loss_args = {'init_scale': 2 ** 16})

if args.restart:
    if os.path.exists(os.path.join(args.restart_dir, 'optimizer.pt')):
        with open(os.path.join(args.restart_dir, 'optimizer.pt'), 'rb') as f:
            opt_state_dict = torch.load(f)
            optimizer.load_state_dict(opt_state_dict)
    else:
        print('Optimizer was not saved. Start from scratch.')

logging('=' * 100)
for k, v in args.__dict__.items():
    logging('    - {} : {}'.format(k, v))
logging('=' * 100)
logging('#params = {}'.format(args.n_all_param))
logging('#non emb params = {}'.format(args.n_nonemb_param))

###############################################################################
# Training code
###############################################################################

def evaluate(eval_iter):
    # Turn on evaluation mode which disables dropout.
    model.eval()

    # If the model does not use memory at all, make the ext_len longer.
    # Otherwise, make the mem_len longer and keep the ext_len the same.
    if args.mem_len == 0:
        model.reset_length(args.eval_tgt_len,
            args.ext_len+args.tgt_len-args.eval_tgt_len, args.mem_len)
    else:
        model.reset_length(args.eval_tgt_len,
            args.ext_len, args.mem_len+args.tgt_len-args.eval_tgt_len)
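    # Editor's note (not part of the original file): this keeps the total
    # attention span tgt_len + mem_len unchanged at evaluation time; e.g. with
    # tgt_len=512, mem_len=512 and eval_tgt_len=128, the eval mem_len becomes
    # 512 + 512 - 128 = 896.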

    # Evaluation
    total_len, total_loss = 0, 0.
    with torch.no_grad():
        mems = tuple()
        for i, (data, target, seq_len) in enumerate(eval_iter):
            if args.max_eval_steps > 0 and i >= args.max_eval_steps:
                break
            ret = model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.mean()
            total_loss += seq_len * loss.float().item()
            total_len += seq_len

    # Switch back to the training mode
    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    model.train()

    return total_loss / total_len


def train():
    # Turn on training mode which enables dropout.
    global train_step, train_loss, best_val_loss, eval_start_time, log_start_time
    model.train()
    if args.batch_chunk > 1:
        mems = [tuple() for _ in range(args.batch_chunk)]
    else:
        mems = tuple()
    train_iter = tr_iter.get_varlen_iter() if args.varlen else tr_iter
    for batch, (data, target, seq_len) in enumerate(train_iter):
        model.zero_grad()
        if args.batch_chunk > 1:
            data_chunks = torch.chunk(data, args.batch_chunk, 1)
            target_chunks = torch.chunk(target, args.batch_chunk, 1)
            for i in range(args.batch_chunk):
                data_i = data_chunks[i].contiguous()
                target_i = target_chunks[i].contiguous()
                ret = para_model(data_i, target_i, *mems[i])
                loss, mems[i] = ret[0], ret[1:]
                loss = loss.float().mean().type_as(loss) / args.batch_chunk
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                train_loss += loss.float().item()
        else:
            ret = para_model(data, target, *mems)
            loss, mems = ret[0], ret[1:]
            loss = loss.float().mean().type_as(loss)
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            train_loss += loss.float().item()

        if args.fp16:
            optimizer.clip_master_grads(args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        optimizer.step()
        if args.sample_softmax > 0:
            optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if args.sample_softmax > 0:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step)
                    if args.sample_softmax > 0:
                        scheduler_sparse.step(train_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / args.log_interval
            elapsed = time.time() - log_start_time
            log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
                      '| ms/batch {:5.2f} | loss {:5.2f}'.format(
                epoch, train_step, batch+1, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
            else:
                log_str += ' | ppl {:9.3f}'.format(math.exp(cur_loss))
            logging(log_str)
            train_loss = 0
            log_start_time = time.time()

        if train_step % args.eval_interval == 0:
            val_loss = evaluate(va_iter)
            logging('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                train_step // args.eval_interval, train_step,
                (time.time() - eval_start_time), val_loss)
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
            logging(log_str)
            logging('-' * 100)
            # Save the model if the validation loss is the best we've seen so far.
            if not best_val_loss or val_loss < best_val_loss:
                if not args.debug:
                    with open(os.path.join(args.work_dir, 'model.pt'), 'wb') as f:
                        torch.save(model, f)
                    with open(os.path.join(args.work_dir, 'optimizer.pt'), 'wb') as f:
                        torch.save(optimizer.state_dict(), f)
                best_val_loss = val_loss

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if args.sample_softmax > 0:
                    scheduler_sparse.step(val_loss)

            eval_start_time = time.time()

        if train_step == args.max_step:
            break

# Loop over epochs.
train_step = 0
train_loss = 0
best_val_loss = None

log_start_time = time.time()
eval_start_time = time.time()

# At any point you can hit Ctrl + C to break out of training early.
try:
    for epoch in itertools.count(start=1):
        train()
        if train_step == args.max_step:
            logging('-' * 100)
            logging('End of training')
            break
except KeyboardInterrupt:
    logging('-' * 100)
    logging('Exiting from training early')

# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
    model = torch.load(f)
para_model = model.to(device)

# Run on test data.
test_loss = evaluate(te_iter)
logging('=' * 100)
if args.dataset in ['enwik8', 'text8']:
    logging('| End of training | test loss {:5.2f} | test bpc {:9.5f}'.format(
        test_loss, test_loss / math.log(2)))
else:
    logging('| End of training | test loss {:5.2f} | test ppl {:9.3f}'.format(
        test_loss, math.exp(test_loss)))
logging('=' * 100)
@@ -1,90 +0,0 @@
from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaptiveLogSoftmax(nn.Module):
    def __init__(self, in_features, n_classes, cutoffs, keep_order=False):
        super(AdaptiveLogSoftmax, self).__init__()

        cutoffs = list(cutoffs)

        if (cutoffs != sorted(cutoffs)) \
                or (min(cutoffs) <= 0) \
                or (max(cutoffs) >= (n_classes - 1)) \
                or (len(set(cutoffs)) != len(cutoffs)) \
                or any([int(c) != c for c in cutoffs]):

            raise ValueError("cutoffs should be a sequence of unique, positive "
                             "integers sorted in an increasing order, where "
                             "each value is between 1 and n_classes-1")

        self.in_features = in_features
        self.n_classes = n_classes
        self.cutoffs = cutoffs + [n_classes]

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

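        # Worked example (editor's note, not part of the original file): with
        # n_classes=10000 and cutoffs=[2000, 6000], self.cutoffs becomes
        # [2000, 6000, 10000]: a 2000-word shortlist plus 2 tail clusters,
        # so the head softmax has 2000 + 2 = 2002 outputs.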
        self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.in_features))
        self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.keep_order = keep_order


    def forward(self, hidden, target, weight, bias, keep_order=False):
        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        head_weight = torch.cat(
            [weight[:self.shortlist_size], self.cluster_weight], dim=0)
        head_bias = torch.cat(
            [bias[:self.shortlist_size], self.cluster_bias], dim=0)

        head_logit = F.linear(hidden, head_weight, bias=head_bias)
        head_logprob = F.log_softmax(head_logit, dim=1)

        nll = torch.zeros_like(target,
                               dtype=hidden.dtype, device=hidden.device)

        offset = 0
        cutoff_values = [0] + self.cutoffs
        for i in range(len(cutoff_values) - 1):
            l_idx, h_idx = cutoff_values[i], cutoff_values[i + 1]

            mask_i = (target >= l_idx) & (target < h_idx)
            indices_i = mask_i.nonzero().squeeze()

            if indices_i.numel() == 0:
                continue

            target_i = target.index_select(0, indices_i) - l_idx
            head_logprob_i = head_logprob.index_select(0, indices_i)

            if i == 0:
                logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
            else:
                weight_i = weight[l_idx:h_idx]
                bias_i = bias[l_idx:h_idx]

                hidden_i = hidden.index_select(0, indices_i)

                tail_logit_i = F.linear(hidden_i, weight_i, bias=bias_i)
                tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                logprob_i = head_logprob_i[:, -i] \
                    + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)

            if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                nll.index_copy_(0, indices_i, -logprob_i)
            else:
                nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)

            offset += logprob_i.size(0)

        return nll
@@ -1,91 +0,0 @@

from torch.nn.parallel import DataParallel
import torch
from torch.nn.parallel._functions import Scatter
from torch.nn.parallel.parallel_apply import parallel_apply

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
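        # Worked example (editor's note, not part of the original file): with
        # bsz=22, gpu0_bsz=4 and 3 devices, bsz_unit = (22 - 4) // 2 = 9 and
        # chunk_sizes = [4, 9, 9]; any remainder is added to the non-first
        # GPUs below.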
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
@@ -1,40 +0,0 @@
import functools
import os, shutil

import numpy as np

import torch


def logging(s, log_path, print_=True, log_=True):
    if print_:
        print(s)
    if log_:
        with open(log_path, 'a+') as f_log:
            f_log.write(s + '\n')

def get_logger(log_path, **kwargs):
    return functools.partial(logging, log_path=log_path, **kwargs)

def create_exp_dir(dir_path, scripts_to_save=None, debug=False):
    if debug:
        print('Debug Mode : no experiment dir created')
        return functools.partial(logging, log_path=None, log_=False)

    if not os.path.exists(dir_path):
        os.makedirs(dir_path)

    print('Experiment dir : {}'.format(dir_path))
    if scripts_to_save is not None:
        script_path = os.path.join(dir_path, 'scripts')
        if not os.path.exists(script_path):
            os.makedirs(script_path)
        for script in scripts_to_save:
            dst_file = os.path.join(dir_path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

    return get_logger(log_path=os.path.join(dir_path, 'log.txt'))

def save_checkpoint(model, optimizer, path, epoch):
    torch.save(model, os.path.join(path, 'model_{}.pt'.format(epoch)))
    torch.save(optimizer.state_dict(), os.path.join(path, 'optimizer_{}.pt'.format(epoch)))
@@ -1,147 +0,0 @@
import torch
from torch import nn
import numpy as np

class LogUniformSampler(object):
    def __init__(self, range_max, n_sample):
        """
        Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
        `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`

        expected count can be approximated by 1 - (1 - p)^n
        and we use a numerically stable version -expm1(num_tries * log1p(-p))

        Our implementation fixes num_tries at 2 * n_sample, and the actual #samples will vary from run to run
        """
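        # Concrete numbers (editor's note, not part of the original file):
        # with range_max=4, P(0) = (log 2 - log 1) / log 5 ~ 0.431 while
        # P(3) = (log 5 - log 4) / log 5 ~ 0.139, so frequent (low-index)
        # classes are proposed far more often.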
        with torch.no_grad():
            self.range_max = range_max
            log_indices = torch.arange(1., range_max+2., 1.).log_()
            self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
            # print('P', self.dist.numpy().tolist()[-30:])

            self.log_q = (- (-self.dist.double().log1p_() * 2 * n_sample).expm1_()).log_().float()

        self.n_sample = n_sample

    def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples

def sample_logits(embedding, bias, labels, inputs, sampler):
    """
        embedding: an nn.Embedding layer
        bias: [n_vocab]
        labels: [b1, b2]
        inputs: [b1, b2, n_emb]
        sampler: you may use a LogUniformSampler
    Return
        logits: [b1, b2, 1 + n_sample]
    """
    true_log_probs, samp_log_probs, neg_samples = sampler.sample(labels)
    n_sample = neg_samples.size(0)
    b1, b2 = labels.size(0), labels.size(1)
    all_ids = torch.cat([labels.view(-1), neg_samples])
    all_w = embedding(all_ids)
    true_w = all_w[: -n_sample].view(b1, b2, -1)
    sample_w = all_w[- n_sample:].view(n_sample, -1)

    all_b = bias[all_ids]
    true_b = all_b[: -n_sample].view(b1, b2)
    sample_b = all_b[- n_sample:]

    hit = (labels[:, :, None] == neg_samples).detach()

    true_logits = torch.einsum('ijk,ijk->ij',
        [true_w, inputs]) + true_b - true_log_probs
    sample_logits = torch.einsum('lk,ijk->ijl',
        [sample_w, inputs]) + sample_b - samp_log_probs
    sample_logits.masked_fill_(hit, -1e30)
    logits = torch.cat([true_logits[:, :, None], sample_logits], -1)

    return logits
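
# Editor's note (not part of the original file): index 0 of the last dimension
# is the true-class logit, which is why the sampled-softmax loss in
# mem_transformer.py reads -F.log_softmax(logit, -1)[:, :, 0].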

# class LogUniformSampler(object):
#     def __init__(self, range_max, unique=False):
#         """
#         Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
#         `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
#         """
#         self.range_max = range_max
#         log_indices = torch.arange(1., range_max+2., 1.).log_()
#         self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]

#         self.unique = unique

#         if self.unique:
#             self.exclude_mask = torch.ByteTensor(range_max).fill_(0)

#     def sample(self, n_sample, labels):
#         pos_sample, new_labels = labels.unique(return_inverse=True)
#         n_pos_sample = pos_sample.size(0)
#         n_neg_sample = n_sample - n_pos_sample

#         if self.unique:
#             self.exclude_mask.index_fill_(0, pos_sample, 1)
#             sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
#             self.exclude_mask.index_fill_(0, pos_sample, 0)
#         else:
#             sample_dist = self.dist

#         neg_sample = torch.multinomial(sample_dist, n_neg_sample)

#         sample = torch.cat([pos_sample, neg_sample])
#         sample_prob = self.dist[sample]

#         return new_labels, sample, sample_prob

if __name__ == '__main__':
    S, B = 3, 4
    n_vocab = 10000
    n_sample = 5
    H = 32

    labels = torch.LongTensor(S, B).random_(0, n_vocab)

    # sampler = LogUniformSampler(n_vocab, unique=False)
    # new_labels, sample, sample_prob = sampler.sample(n_sample, labels)

    sampler = LogUniformSampler(n_vocab, n_sample)
    # true_probs, samp_probs, neg_samples = sampler.sample(labels)

    # print('true_probs', true_probs.numpy().tolist())
    # print('samp_probs', samp_probs.numpy().tolist())
    # print('neg_samples', neg_samples.numpy().tolist())

    # print('sum', torch.sum(sampler.dist).item())

    # assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()

    embedding = nn.Embedding(n_vocab, H)
    bias = torch.zeros(n_vocab)
    inputs = torch.Tensor(S, B, H).normal_()

    logits = sample_logits(embedding, bias, labels, inputs, sampler)
    print('logits', logits.detach().numpy().tolist())
    print('logits shape', logits.size())
@@ -1,151 +0,0 @@
from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

CUDA_MAJOR = int(torch.version.cuda.split('.')[0])
CUDA_MINOR = int(torch.version.cuda.split('.')[1])

class ProjectedAdaptiveLogSoftmax(nn.Module):
    def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1,
                 keep_order=False):
        super(ProjectedAdaptiveLogSoftmax, self).__init__()

        self.n_token = n_token
        self.d_embed = d_embed
        self.d_proj = d_proj

        self.cutoffs = cutoffs + [n_token]
        self.cutoff_ends = [0] + self.cutoffs
        self.div_val = div_val

        self.shortlist_size = self.cutoffs[0]
        self.n_clusters = len(self.cutoffs) - 1
        self.head_size = self.shortlist_size + self.n_clusters

        if self.n_clusters > 0:
            self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed))
            self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters))

        self.out_layers = nn.ModuleList()
        self.out_projs = nn.ParameterList()

        if div_val == 1:
            for i in range(len(self.cutoffs)):
                if d_proj != d_embed:
                    self.out_projs.append(
                        nn.Parameter(torch.Tensor(d_proj, d_embed))
                    )
                else:
                    self.out_projs.append(None)

            self.out_layers.append(nn.Linear(d_embed, n_token))
        else:
            for i in range(len(self.cutoffs)):
                l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1]
                d_emb_i = d_embed // (div_val ** i)

                self.out_projs.append(
                    nn.Parameter(torch.Tensor(d_proj, d_emb_i))
                )

                self.out_layers.append(nn.Linear(d_emb_i, r_idx-l_idx))

        self.keep_order = keep_order

    def _compute_logit(self, hidden, weight, bias, proj):
        if proj is None:
            logit = F.linear(hidden, weight, bias=bias)
        else:
            # if CUDA_MAJOR <= 9 and CUDA_MINOR <= 1:
            proj_hid = F.linear(hidden, proj.t().contiguous())
            logit = F.linear(proj_hid, weight, bias=bias)
            # else:
            #     logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t()))
            #     if bias is not None:
            #         logit = logit + bias

        return logit

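    # Editor's note (not part of the original file): with div_val > 1 each
    # tail cluster i uses a smaller embedding size d_embed // (div_val ** i),
    # and the projection above maps the d_proj-sized hidden state down to that
    # size, so rarer vocabulary buckets spend fewer parameters.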
    def forward(self, hidden, target, keep_order=False):
        '''
            hidden :: [len*bsz x d_proj]
            target :: [len*bsz]
        '''

        if hidden.size(0) != target.size(0):
            raise RuntimeError('Input and target should have the same size '
                               'in the batch dimension.')

        if self.n_clusters == 0:
            logit = self._compute_logit(hidden, self.out_layers[0].weight,
                                        self.out_layers[0].bias, self.out_projs[0])
            nll = -F.log_softmax(logit, dim=-1) \
                    .gather(1, target.unsqueeze(1)).squeeze(1)
        else:
            # construct weights and biases
            weights, biases = [], []
            for i in range(len(self.cutoffs)):
                if self.div_val == 1:
                    l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                    weight_i = self.out_layers[0].weight[l_idx:r_idx]
                    bias_i = self.out_layers[0].bias[l_idx:r_idx]
                else:
                    weight_i = self.out_layers[i].weight
                    bias_i = self.out_layers[i].bias

                if i == 0:
                    weight_i = torch.cat(
                        [weight_i, self.cluster_weight], dim=0)
                    bias_i = torch.cat(
                        [bias_i, self.cluster_bias], dim=0)

                weights.append(weight_i)
                biases.append(bias_i)

            head_weight, head_bias, head_proj = weights[0], biases[0], self.out_projs[0]

            head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj)
            head_logprob = F.log_softmax(head_logit, dim=1)

            nll = torch.zeros_like(target,
                                   dtype=hidden.dtype, device=hidden.device)

            offset = 0
            cutoff_values = [0] + self.cutoffs
            for i in range(len(cutoff_values) - 1):
                l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1]

                mask_i = (target >= l_idx) & (target < r_idx)
                indices_i = mask_i.nonzero().squeeze()

                if indices_i.numel() == 0:
                    continue

                target_i = target.index_select(0, indices_i) - l_idx
                head_logprob_i = head_logprob.index_select(0, indices_i)

                if i == 0:
                    logprob_i = head_logprob_i.gather(1, target_i[:,None]).squeeze(1)
                else:
                    weight_i, bias_i, proj_i = weights[i], biases[i], self.out_projs[i]

                    hidden_i = hidden.index_select(0, indices_i)

                    tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i)
                    tail_logprob_i = F.log_softmax(tail_logit_i, dim=1)

                    logprob_i = head_logprob_i[:, -i] \
                        + tail_logprob_i.gather(1, target_i[:,None]).squeeze(1)

                if (hasattr(self, 'keep_order') and self.keep_order) or keep_order:
                    nll.index_copy_(0, indices_i, -logprob_i)
                else:
                    nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i)

                offset += logprob_i.size(0)

        return nll
@@ -1,163 +0,0 @@
import os
from collections import Counter, OrderedDict

import torch

class Vocab(object):
    def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
                 delimiter=None, vocab_file=None):
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file

    def tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        # empty delimiter '' will evaluate False
        if self.delimiter == '':
            symbols = line
        else:
            symbols = line.split(self.delimiter)

        if add_double_eos: # lm1b
            return ['<S>'] + symbols + ['<S>']
        elif add_eos:
            return symbols + ['<eos>']
        else:
            return symbols

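    # Editor's note (not part of the original file): delimiter=None splits on
    # whitespace (word-level vocab), while delimiter='' keeps the raw string so
    # callers iterate over individual characters (character-level vocab).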
def count_file(self, path, verbose=False, add_eos=False):
|
||||
if verbose: print('counting file {} ...'.format(path))
|
||||
assert os.path.exists(path)
|
||||
|
||||
sents = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for idx, line in enumerate(f):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
symbols = self.tokenize(line, add_eos=add_eos)
|
||||
self.counter.update(symbols)
|
||||
sents.append(symbols)
|
||||
|
||||
return sents
|
||||
|
||||
def count_sents(self, sents, verbose=False):
|
||||
"""
|
||||
sents : a list of sentences, each a list of tokenized symbols
|
||||
"""
|
||||
if verbose: print('counting {} sents ...'.format(len(sents)))
|
||||
for idx, symbols in enumerate(sents):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
self.counter.update(symbols)
|
||||
|
||||
def _build_from_file(self, vocab_file):
|
||||
self.idx2sym = []
|
||||
self.sym2idx = OrderedDict()
|
||||
|
||||
with open(vocab_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
symb = line.strip().split()[0]
|
||||
self.add_symbol(symb)
|
||||
self.unk_idx = self.sym2idx['<UNK>']
|
||||
|
||||
def build_vocab(self):
|
||||
if self.vocab_file:
|
||||
print('building vocab from {}'.format(self.vocab_file))
|
||||
self._build_from_file(self.vocab_file)
|
||||
print('final vocab size {}'.format(len(self)))
|
||||
else:
|
||||
print('building vocab with min_freq={}, max_size={}'.format(
|
||||
self.min_freq, self.max_size))
|
||||
self.idx2sym = []
|
||||
self.sym2idx = OrderedDict()
|
||||
|
||||
for sym in self.special:
|
||||
self.add_special(sym)
|
||||
|
||||
for sym, cnt in self.counter.most_common(self.max_size):
|
||||
if cnt < self.min_freq: break
|
||||
self.add_symbol(sym)
|
||||
|
||||
print('final vocab size {} from {} unique tokens'.format(
|
||||
len(self), len(self.counter)))
|
||||
|
||||
def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
|
||||
add_double_eos=False):
|
||||
if verbose: print('encoding file {} ...'.format(path))
|
||||
assert os.path.exists(path)
|
||||
encoded = []
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
for idx, line in enumerate(f):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
symbols = self.tokenize(line, add_eos=add_eos,
|
||||
add_double_eos=add_double_eos)
|
||||
encoded.append(self.convert_to_tensor(symbols))
|
||||
|
||||
if ordered:
|
||||
encoded = torch.cat(encoded)
|
||||
|
||||
return encoded
|
||||
|
||||
def encode_sents(self, sents, ordered=False, verbose=False):
|
||||
if verbose: print('encoding {} sents ...'.format(len(sents)))
|
||||
encoded = []
|
||||
for idx, symbols in enumerate(sents):
|
||||
if verbose and idx > 0 and idx % 500000 == 0:
|
||||
print(' line {}'.format(idx))
|
||||
encoded.append(self.convert_to_tensor(symbols))
|
||||
|
||||
if ordered:
|
||||
encoded = torch.cat(encoded)
|
||||
|
||||
return encoded
|
||||
|
||||
def add_special(self, sym):
|
||||
if sym not in self.sym2idx:
|
||||
self.idx2sym.append(sym)
|
||||
self.sym2idx[sym] = len(self.idx2sym) - 1
|
||||
setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])
|
||||
|
||||
def add_symbol(self, sym):
|
||||
if sym not in self.sym2idx:
|
||||
self.idx2sym.append(sym)
|
||||
self.sym2idx[sym] = len(self.idx2sym) - 1
|
||||
|
||||
def get_sym(self, idx):
|
||||
assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
|
||||
return self.idx2sym[idx]
|
||||
|
||||
def get_idx(self, sym):
|
||||
if sym in self.sym2idx:
|
||||
return self.sym2idx[sym]
|
||||
else:
|
||||
# print('encounter unk {}'.format(sym))
|
||||
assert '<eos>' not in sym
|
||||
assert hasattr(self, 'unk_idx')
|
||||
return self.sym2idx.get(sym, self.unk_idx)
|
||||
|
||||
def get_symbols(self, indices):
|
||||
return [self.get_sym(idx) for idx in indices]
|
||||
|
||||
def get_indices(self, symbols):
|
||||
return [self.get_idx(sym) for sym in symbols]
|
||||
|
||||
def convert_to_tensor(self, symbols):
|
||||
return torch.LongTensor(self.get_indices(symbols))
|
||||
|
||||
def convert_to_sent(self, indices, exclude=None):
|
||||
if exclude is None:
|
||||
return ' '.join([self.get_sym(idx) for idx in indices])
|
||||
else:
|
||||
return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.idx2sym)
|
||||
|
|
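A typical end-to-end use of this class looks as follows; this is a sketch derived from the methods above, and the file paths are placeholders:

```python
# Hypothetical usage of the Vocab class above; paths are placeholders.
vocab = Vocab(special=['<eos>'], lower_case=False)
vocab.count_file('data/wikitext-103/train.txt', add_eos=True)
vocab.build_vocab()

# encode_file returns one flat LongTensor when ordered=True.
train_ids = vocab.encode_file('data/wikitext-103/train.txt',
                              ordered=True, add_eos=True)
print(len(vocab), train_ids.shape)
```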
@@ -1,131 +0,0 @@

## Introduction

This directory contains our TF implementation of Transformer-XL. Note that our state-of-the-art results reported in the paper were obtained by training the model on a large-scale TPU cluster, and our GPU codebase currently does not support distributed training. Here we provide two sets of hyperparameters and scripts:
- `*large_tpu.sh` are for the SoTA setting on TPUs. These are exactly the commands we used to obtain our best results.
- `*base_gpu.sh` are for the base models which can be run on a few GPUs.

## Prerequisite

- Python 2.7
- Tensorflow [1.12.0](https://github.com/tensorflow/tensorflow/releases/tag/v1.12.0)

## Obtain and evaluate pretrained SoTA models

#### 1. Download preprocessed data (vocab) & pretrained models

(a) Set your own `DATA_ROOT` in `sota/download.sh` (defaults to `./`), which will be the root directory of the downloaded model.

(b) Then, download the model & data by `bash sota/download.sh`. After downloading, the expected directory structure is as follows:

```markdown
pretrained_xl
    tf_enwik8/
        data/
            cache.pkl
            corpus-info.json
        model/
            checkpoint
            model.ckpt*
    tf_wt103/
        ...
    ...
```

**Note**: we include preprocessed data in the download files to make sure the **same vocabulary** is used. Please see the code `tf/data_utils.py` to understand the data structure.

#### 2. Run evaluation scripts to replicate SoTA results on GPUs

- **enwik8**: modify the script `sota/enwik8.sh` accordingly (see below)
    - set `DATA_ROOT` to the same folder used in the download step (defaults to `./`)
    - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 2 GPUs => about 60 mins
    - run the script: `bash sota/enwik8.sh`

- **lm1b**: modify the script `sota/lm1b.sh` accordingly (see below)
    - set `DATA_ROOT` to the same folder used in the download step (defaults to `./`)
    - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 1 GPU => less than 5 mins
    - run the script: `bash sota/lm1b.sh`

- **wt103**: modify the script `sota/wt103.sh` accordingly (see below)
    - set `DATA_ROOT` to the same folder used in the download step (defaults to `./`)
    - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 1 GPU => less than 5 mins
    - run the script: `bash sota/wt103.sh`

- **text8**: modify the script `sota/text8.sh` accordingly (see below)
    - set `DATA_ROOT` to the same folder used in the download step (defaults to `./`)
    - set `TEST_NUM_CORE` (number of GPUs to use): we recommend 2 GPUs => about 60 mins
    - run the script: `bash sota/text8.sh`

A consolidated sketch of these steps for the enwik8 case follows this list.
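Putting the download and evaluation steps together for enwik8 (a sketch; `DATA_ROOT` and the GPU count are placeholders you should adapt):

```shell
# Sketch of the enwik8 evaluation flow described above; values are placeholders.
# 1. Set DATA_ROOT inside sota/download.sh (defaults to ./), then download:
bash sota/download.sh

# 2. Edit sota/enwik8.sh: point DATA_ROOT at the same folder, set TEST_NUM_CORE=2.
# 3. Run the evaluation:
bash sota/enwik8.sh
```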
#### 3. Resources Needed for SoTA Model Training

We used 32, 32, 64, and 512 TPU cores for training our best models on enwik8, text8, wt103, and lm1b respectively. The training time for each model ranges from 2 to 5 days.

## Train "Transformer-XL" from scratch with GPUs or TPUs

### 1. Download raw data

`bash getdata.sh`

### 2. Preprocess, training and evaluation

For `dataset` in `[enwik8, lm1b, wt103, text8]`:

- check out `scripts/dataset_base_gpu.sh` for GPU training and evaluation
- check out `scripts/dataset_large_tpu.sh` for TPU training and evaluation

#### (1) Preprocess raw data and create tfrecords

**NOTE**: The preprocessing for GPU and TPU is different, so you have to run them separately.

GPU:

- create training and validation data: `bash scripts/dataset_base_gpu.sh train_data`
- create test data: `bash scripts/dataset_base_gpu.sh test_data`

TPU:

- Set the Google storage URL in `scripts/dataset_large_tpu.sh`:
    - `GSDATA`: data URL
    - `GSEXP`: experiment URL
- create training and validation data: `bash scripts/dataset_large_tpu.sh train_data`
- create test data: `bash scripts/dataset_large_tpu.sh test_data`

#### (2) Run training

Base models on GPUs:

- Modify the configurations in `scripts/dataset_base_gpu.sh` according to your needs.
- `bash scripts/dataset_base_gpu.sh train`
- If enough resources are available, increase the model sizes (e.g., `N_LAYER`, `D_MODEL`, `D_EMBED`, `D_HEAD`, `D_INNER`) so that they are closer to the values defined in `scripts/dataset_large_tpu.sh`. Likewise, when resources are limited, decrease the model sizes. It is recommended to ensure that `D_MODEL == D_EMBED` and `D_MODEL == N_HEAD x D_HEAD` (a sanity-check sketch follows this list). When the model sizes increase, remember to increase `warmup_steps` accordingly to alleviate optimization difficulties.
- Adjust the `NUM_CORE` parameter to reflect the number of GPUs to use.
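A minimal sanity check for the size relations above; the example values are the base-GPU defaults from the enwik8 script later in this diff, used here purely for illustration:

```python
# Illustrative check of the recommended dimension relations.
N_LAYER, D_MODEL, D_EMBED = 12, 512, 512
N_HEAD, D_HEAD, D_INNER = 8, 64, 2048

assert D_MODEL == D_EMBED, "embedding and model width should match"
assert D_MODEL == N_HEAD * D_HEAD, "heads must tile the model width exactly"
```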
Larger models on TPUs:

- Modify the configurations in `scripts/dataset_large_tpu.sh` according to your needs.
- `bash scripts/dataset_large_tpu.sh train`

#### (3) Run evaluation

Base models on GPUs:

- `bash scripts/dataset_base_gpu.sh eval --eval_ckpt_path PATH_TO_CKPT`

Larger models on TPUs:

- `bash scripts/dataset_large_tpu.sh eval --eval_ckpt_path PATH_TO_CKPT`
@@ -1,118 +0,0 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Script to average values of variables in a list of checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import numpy as np
import six
from six.moves import zip  # pylint: disable=redefined-builtin
import tensorflow as tf

flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("checkpoints", "",
                    "Comma-separated list of checkpoints to average.")
flags.DEFINE_integer("num_last_checkpoints", 0,
                     "Averages the last N saved checkpoints."
                     " If the checkpoints flag is set, this is ignored.")
flags.DEFINE_string("prefix", "",
                    "Prefix (e.g., directory) to append to each checkpoint.")
flags.DEFINE_string("output_path", "/tmp/averaged.ckpt",
                    "Path to output the averaged checkpoint to.")


def checkpoint_exists(path):
  return (tf.gfile.Exists(path) or tf.gfile.Exists(path + ".meta") or
          tf.gfile.Exists(path + ".index"))


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  if FLAGS.checkpoints:
    # Get the checkpoints list from flags and run some basic checks.
    checkpoints = [c.strip() for c in FLAGS.checkpoints.split(",")]
    checkpoints = [c for c in checkpoints if c]
    if not checkpoints:
      raise ValueError("No checkpoints provided for averaging.")
    if FLAGS.prefix:
      checkpoints = [FLAGS.prefix + c for c in checkpoints]
  else:
    assert FLAGS.num_last_checkpoints >= 1, "Must average at least one model"
    assert FLAGS.prefix, ("Prefix must be provided when averaging last"
                          " N checkpoints")
    checkpoint_state = tf.train.get_checkpoint_state(
        os.path.dirname(FLAGS.prefix))
    # Checkpoints are ordered from oldest to newest.
    checkpoints = checkpoint_state.all_model_checkpoint_paths[
        -FLAGS.num_last_checkpoints:]

  checkpoints = [c for c in checkpoints if checkpoint_exists(c)]
  if not checkpoints:
    if FLAGS.checkpoints:
      raise ValueError(
          "None of the provided checkpoints exist. %s" % FLAGS.checkpoints)
    else:
      raise ValueError("Could not find checkpoints at %s" %
                       os.path.dirname(FLAGS.prefix))

  # Read variables from all checkpoints and average them.
  tf.logging.info("Reading variables and averaging checkpoints:")
  for c in checkpoints:
    tf.logging.info("%s ", c)
  var_list = tf.contrib.framework.list_variables(checkpoints[0])
  var_values, var_dtypes = {}, {}
  for (name, shape) in var_list:
    if not name.startswith("global_step"):
      var_values[name] = np.zeros(shape)
  for checkpoint in checkpoints:
    reader = tf.contrib.framework.load_checkpoint(checkpoint)
    for name in var_values:
      tensor = reader.get_tensor(name)
      var_dtypes[name] = tensor.dtype
      var_values[name] += tensor
    tf.logging.info("Read from checkpoint %s", checkpoint)
  for name in var_values:  # Average.
    var_values[name] /= len(checkpoints)

  with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
    tf_vars = [
        tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
        for v in var_values
    ]
  placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
  assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]
  global_step = tf.Variable(
      0, name="global_step", trainable=False, dtype=tf.int64)
  saver = tf.train.Saver(tf.all_variables())

  # Build a model consisting only of variables, set them to the average values.
  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                           six.iteritems(var_values)):
      sess.run(assign_op, {p: value})
    # Use the built saver to save the averaged checkpoint.
    saver.save(sess, FLAGS.output_path, global_step=global_step)

  tf.logging.info("Averaged checkpoints saved in %s", FLAGS.output_path)


if __name__ == "__main__":
  tf.app.run()
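The core of this script is just an elementwise mean over matching variables across checkpoints; a toy sketch of that operation with made-up values:

```python
import numpy as np

# Toy stand-ins for the same variable read from three checkpoints.
snapshots = [np.array([1.0, 2.0]), np.array([3.0, 2.0]), np.array([5.0, 5.0])]

avg = np.zeros_like(snapshots[0])
for tensor in snapshots:
    avg += tensor
avg /= len(snapshots)
print(avg)  # -> [3. 3.]
```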
@@ -1,586 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
from functools import partial

from collections import Counter, OrderedDict
import pickle
import json
import multiprocessing as mp

import numpy as np

from absl import flags
import tensorflow as tf
from vocabulary import Vocab

from tensorflow.gfile import Exists as exists
from tensorflow.gfile import MakeDirs as makedirs
from tensorflow.gfile import Glob as glob


def _preprocess(shard, train, vocab, save_dir, cutoffs, bin_sizes, bsz, tgt_len,
                num_core_per_host, use_tpu, num_shuffle):
  file_names = []
  num_batch = 0

  path = train[shard]
  data_shard = vocab.encode_file(path, ordered=False, add_double_eos=True)

  for shuffle in range(num_shuffle):
    basename = "train-{:03d}-{:02d}".format(shard, shuffle)
    print("Processing shard {} shuffle {}".format(shard, shuffle))

    np.random.shuffle(data_shard)
    file_name, num_batch_shuffle = create_ordered_tfrecords(
        save_dir, basename, np.concatenate(data_shard), bsz, tgt_len,
        num_core_per_host, cutoffs, bin_sizes, use_tpu=use_tpu)
    file_names.append(file_name)
    num_batch += num_batch_shuffle

  return file_names, num_batch


class Corpus(object):
  def __init__(self, path, dataset, *args, **kwargs):
    self.dataset = dataset
    self.vocab = Vocab(*args, **kwargs)

    if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
      self.vocab.count_file(os.path.join(path, "train.txt"))
      self.vocab.count_file(os.path.join(path, "valid.txt"))
      self.vocab.count_file(os.path.join(path, "test.txt"))
    elif self.dataset == "wt103":
      self.vocab.count_file(os.path.join(path, "train.txt"))
    elif self.dataset == "lm1b":
      train_path_pattern = os.path.join(
          path, "1-billion-word-language-modeling-benchmark-r13output",
          "training-monolingual.tokenized.shuffled", "news.en-*")
      train_paths = glob(train_path_pattern)

      # the vocab will load from file when build_vocab() is called
      # for train_path in sorted(train_paths):
      #   self.vocab.count_file(train_path, verbose=True)

    self.vocab.build_vocab()

    if self.dataset in ["ptb", "wt2", "wt103"]:
      self.train = self.vocab.encode_file(
          os.path.join(path, "train.txt"), ordered=True)
      self.valid = self.vocab.encode_file(
          os.path.join(path, "valid.txt"), ordered=True)
      self.test = self.vocab.encode_file(
          os.path.join(path, "test.txt"), ordered=True)
    elif self.dataset in ["enwik8", "text8"]:
      self.train = self.vocab.encode_file(
          os.path.join(path, "train.txt"), ordered=True, add_eos=False)
      self.valid = self.vocab.encode_file(
          os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
      self.test = self.vocab.encode_file(
          os.path.join(path, "test.txt"), ordered=True, add_eos=False)
    elif self.dataset == "lm1b":
      self.train = train_paths
      valid_path = os.path.join(path, "valid.txt")
      test_path = valid_path
      self.valid = self.vocab.encode_file(
          valid_path, ordered=True, add_double_eos=True)
      self.test = self.vocab.encode_file(
          test_path, ordered=True, add_double_eos=True)

    if self.dataset == "wt103":
      self.cutoffs = [0, 20000, 40000, 200000] + [len(self.vocab)]
    elif self.dataset == "lm1b":
      self.cutoffs = [0, 60000, 100000, 640000] + [len(self.vocab)]
    else:
      self.cutoffs = []

  def convert_to_tfrecords(self, split, save_dir, bsz, tgt_len,
                           num_core_per_host, **kwargs):
    FLAGS = kwargs.get('FLAGS')

    file_names = []
    use_tpu = FLAGS.use_tpu and not (split == "test" and num_core_per_host == 1)

    if use_tpu:
      record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
          split, bsz, tgt_len, num_core_per_host)
    else:
      record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
          split, bsz, tgt_len)

    record_info_path = os.path.join(save_dir, record_name)

    if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
      data = getattr(self, split)
      bin_sizes = get_bin_sizes(
          data, bsz // num_core_per_host, tgt_len, self.cutoffs)
      file_name, num_batch = create_ordered_tfrecords(
          save_dir, split, data, bsz, tgt_len, num_core_per_host,
          self.cutoffs, bin_sizes,
          num_passes=FLAGS.num_passes if split == 'train' and use_tpu else 1,
          use_tpu=use_tpu)
      file_names.append(file_name)
    elif self.dataset == "lm1b":
      bin_sizes = get_bin_sizes(
          self.valid, bsz // num_core_per_host, tgt_len, self.cutoffs)
      if split == "train":
        np.random.seed(123456)
        num_batch = 0

        if FLAGS.num_procs > 1:
          _preprocess_wrapper = partial(_preprocess,
              train=self.train, vocab=self.vocab, save_dir=save_dir,
              cutoffs=self.cutoffs, bin_sizes=bin_sizes, bsz=bsz,
              tgt_len=tgt_len, num_core_per_host=num_core_per_host,
              use_tpu=use_tpu, num_shuffle=FLAGS.num_shuffle)

          pool = mp.Pool(processes=FLAGS.num_procs)
          results = pool.map(_preprocess_wrapper, range(len(self.train)))
          for res in results:
            file_names.extend(res[0])
            num_batch += res[1]
        else:
          for shard, path in enumerate(self.train):
            data_shard = self.vocab.encode_file(path, ordered=False,
                                                add_double_eos=True)

            num_shuffle = FLAGS.num_shuffle

            for shuffle in range(num_shuffle):
              print("Processing shard {} shuffle {}".format(shard, shuffle))
              basename = "train-{:03d}-{:02d}".format(shard, shuffle)
              np.random.shuffle(data_shard)
              file_name, num_batch_ = create_ordered_tfrecords(
                  save_dir, basename, np.concatenate(data_shard), bsz, tgt_len,
                  num_core_per_host,
                  self.cutoffs, bin_sizes, use_tpu=use_tpu)
              file_names.append(file_name)
              num_batch += num_batch_

      else:
        file_name, num_batch = create_ordered_tfrecords(
            save_dir, split, getattr(self, split), bsz, tgt_len,
            num_core_per_host,
            self.cutoffs, bin_sizes, use_tpu=use_tpu)
        file_names.append(file_name)

    with open(record_info_path, "w") as fp:
      record_info = {
          "filenames": file_names,
          "bin_sizes": bin_sizes,
          "num_batch": num_batch
      }
      json.dump(record_info, fp)


def get_bin_sizes(data, batch_size, tgt_len, cutoffs, std_mult=[2.5, 2.5, 2.5]):
  """
  Note: the `batch_size` here should be per-core batch size
  """
  bin_sizes = []

  def _nearest_to_eight(x):  # so that it's faster on TPUs
    y = x - x % 8
    return y + 8 if x % 8 >= 4 else max(8, y)

  if cutoffs:
    num_batch = len(data) // batch_size // tgt_len

    data = data[:batch_size * num_batch * tgt_len]
    data = data.reshape(batch_size, num_batch, tgt_len)

    tot = batch_size * tgt_len
    for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])):
      mask = (data >= left) * (data < right)
      percents = mask.astype(np.float64).sum(2).sum(0) / tot
      mean = np.mean(percents)
      std = np.std(percents)

      bin_size = int(math.ceil(tgt_len * batch_size * (mean + std_mult[b] * std)))
      bin_size = _nearest_to_eight(bin_size)
      bin_sizes.append(bin_size)

  return bin_sizes


def _int64_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_feature(values):
  return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def batchify(data, batch_size, num_passes):
  """
  if use_tpu = True: num_passes > 1

  Since TPU training requires entire [bsz x tgt_len] chunks, it can discard
  as many as `bsz * tgt_len` tokens in training. When `bsz` and `tgt_len` are
  both large, as in the case of TPU training for Transformer-XL, the problem
  may lead to a detectable performance drop.

  Here, we use multiple randomly shifted copies to deal with this problem.
  """
  if num_passes > 1:
    data_len = len(data)
    double_data = np.concatenate([data, data])
    data_list = []
    for i in range(num_passes):
      start = np.random.randint(0, data_len)
      data_list.append(double_data[start:start+data_len])
    data = np.concatenate(data_list)

  num_step = len(data) // batch_size
  data = data[:batch_size * num_step]
  data = data.reshape(batch_size, num_step)

  return data
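
# Illustrative only (not part of the original file): what the multi-pass shift
# in `batchify` does for a toy stream with num_passes=2.
#
#   data = np.arange(10)                        # original token stream
#   double_data = np.concatenate([data, data])
#   double_data[3:13]  # -> [3 4 5 6 7 8 9 0 1 2], a random rotation
#
# Each pass is a rotation of the full stream, so the [bsz x tgt_len] chunk
# boundaries (and the tokens discarded at the tail) differ across passes.
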
def create_ordered_tfrecords(save_dir, basename, data, batch_size, tgt_len,
                             num_core_per_host, cutoffs=[], bin_sizes=[],
                             num_passes=1, use_tpu=False):

  if use_tpu:
    file_name = "{}.bsz-{}.tlen-{}.core-{}.tfrecords".format(
        basename, batch_size, tgt_len, num_core_per_host)
  else:
    file_name = "{}.bsz-{}.tlen-{}.tfrecords".format(
        basename, batch_size, tgt_len)

  save_path = os.path.join(save_dir, file_name)
  record_writer = tf.python_io.TFRecordWriter(save_path)

  batched_data = batchify(data, batch_size, num_passes)

  num_batch = 0
  # for t in range(0, batched_data.shape[1] - tgt_len - 1, tgt_len):
  for t in range(0, batched_data.shape[1] - 1, tgt_len):
    cur_tgt_len = min(batched_data.shape[1] - 1 - t, tgt_len)
    # drop the remainder if use tpu
    if use_tpu and cur_tgt_len < tgt_len:
      break
    if num_batch % 500 == 0:
      print("  processing batch {}".format(num_batch))
    for idx in range(batch_size):
      inputs = batched_data[idx, t:t + cur_tgt_len]
      labels = batched_data[idx, t + 1:t + cur_tgt_len + 1]

      # features dict
      feature = {
          "inputs": _int64_feature(inputs),
          "labels": _int64_feature(labels),
      }

      if len(cutoffs) > 0 and use_tpu:
        # validate `bin_sizes` and `cutoffs`
        assert len(cutoffs) - len(bin_sizes) == 2, \
            "len(cutoffs) - len(bin_sizes) != 2"

        # mask for bin 0
        left, right = cutoffs[:2]
        inp_mask = ((inputs >= left) * (inputs < right)).astype(np.float32)
        tgt_mask = ((labels >= left) * (labels < right)).astype(np.float32)

        feature["inp_mask"] = _float_feature(inp_mask)
        feature["tgt_mask"] = _float_feature(tgt_mask)

        # refresh `inp_cnts` and `tgt_cnts` for each TPU core
        if idx % (batch_size // num_core_per_host) == 0:
          inp_cnts = [0] * len(bin_sizes)
          tgt_cnts = [0] * len(bin_sizes)

        head_labels = np.copy(labels)
        inp_pos_per_bin, tgt_pos_per_bin = [], []
        for b, (left, right) in enumerate(zip(cutoffs[1:-1], cutoffs[2:])):
          inp_pos = np.where((inputs >= left) * (inputs < right))[0]
          tgt_pos = np.where((labels >= left) * (labels < right))[0]
          inp_pos_per_bin.append(inp_pos)
          tgt_pos_per_bin.append(tgt_pos)

          head_labels[tgt_pos] = cutoffs[1] + b

        feature["head_labels"] = _int64_feature(head_labels)

        # permutation feature
        def _add_perm_feature(feature, pos_per_bin, cnts, prefix):
          for b, pos in enumerate(pos_per_bin):
            idx_tuple = []
            for p in pos:
              if cnts[b] < bin_sizes[b]:
                idx_tuple.append([p, cnts[b]])
                cnts[b] += 1
              else:
                break

            n_tup = len(idx_tuple)
            tup = np.array(idx_tuple).reshape(n_tup * 2)

            feature["{}_cnt_{}".format(prefix, b)] = _int64_feature([n_tup])
            feature["{}_tup_{}".format(prefix, b)] = _int64_feature(tup)

        _add_perm_feature(feature, inp_pos_per_bin, inp_cnts, "inp")
        _add_perm_feature(feature, tgt_pos_per_bin, tgt_cnts, "tgt")

      example = tf.train.Example(features=tf.train.Features(feature=feature))
      record_writer.write(example.SerializeToString())

    num_batch += 1

  record_writer.close()
  print("Done writing {}. batches: {}".format(file_name, num_batch))

  return file_name, num_batch


def get_lm_corpus(data_dir, dataset):
  fn = os.path.join(data_dir, "cache.pkl")

  if exists(fn):
    print("Loading cached dataset...")
    with open(fn, "rb") as fp:
      corpus = pickle.load(fp)
  else:
    print("Producing dataset...")
    kwargs = {}
    if dataset in ["wt103", "wt2"]:
      kwargs["special"] = ["<eos>"]
      kwargs["lower_case"] = False
    elif dataset == "ptb":
      kwargs["special"] = ["<eos>"]
      kwargs["lower_case"] = True
    elif dataset == "lm1b":
      kwargs["special"] = []
      kwargs["lower_case"] = False
      kwargs["vocab_file"] = os.path.join(data_dir, "1b_word_vocab.txt")
    elif dataset in ["enwik8", "text8"]:
      pass

    corpus = Corpus(data_dir, dataset, **kwargs)

    print("Saving dataset...")
    with open(fn, "wb") as fp:
      pickle.dump(corpus, fp, protocol=2)

  corpus_info = {
      "vocab_size": len(corpus.vocab),
      "cutoffs": corpus.cutoffs,
      "dataset": corpus.dataset
  }
  with open(os.path.join(data_dir, "corpus-info.json"), "w") as fp:
    json.dump(corpus_info, fp)

  return corpus


def main(unused_argv):
  del unused_argv  # Unused

  corpus = get_lm_corpus(FLAGS.data_dir, FLAGS.dataset)

  save_dir = os.path.join(FLAGS.data_dir, "tfrecords")
  if not exists(save_dir):
    makedirs(save_dir)

  # test mode
  if FLAGS.per_host_test_bsz > 0:
    corpus.convert_to_tfrecords("test", save_dir, FLAGS.per_host_test_bsz,
                                FLAGS.tgt_len, FLAGS.num_core_per_host,
                                FLAGS=FLAGS)
    return

  for split, batch_size in zip(
      ["train", "valid"],
      [FLAGS.per_host_train_bsz, FLAGS.per_host_valid_bsz]):

    if batch_size <= 0:
      continue
    print("Converting {} set...".format(split))
    corpus.convert_to_tfrecords(split, save_dir, batch_size, FLAGS.tgt_len,
                                FLAGS.num_core_per_host, FLAGS=FLAGS)


def load_record_info(record_info_dir, split, per_host_bsz, tgt_len,
                     num_core_per_host, use_tpu):
  if use_tpu:
    record_name = "record_info-{}.bsz-{}.tlen-{}.core-{}.json".format(
        split, per_host_bsz, tgt_len, num_core_per_host)
  else:
    record_name = "record_info-{}.bsz-{}.tlen-{}.json".format(
        split, per_host_bsz, tgt_len)

  record_info_path = os.path.join(record_info_dir, record_name)
  with open(record_info_path, "r") as fp:
    record_info = json.load(fp)

  return record_info


def get_input_fn(record_info_dir, split, per_host_bsz, tgt_len,
                 num_core_per_host, num_hosts=1, use_tpu=False):
  """Creates input function."""
  record_info = load_record_info(record_info_dir, split, per_host_bsz, tgt_len,
                                 num_core_per_host, use_tpu=use_tpu)

  file_names = record_info["filenames"]
  bin_sizes = record_info["bin_sizes"]
  num_batch = record_info["num_batch"]

  tf.logging.info("[{}] File names {}".format(split, file_names))

  def input_fn(params):
    # per-core batch size
    per_core_bsz = params["batch_size"]

    # data_dir could be a remote path, e.g., a google storage url
    data_dir = params["data_dir"]

    def parser(record):
      # preprocess "inp_perm" and "tgt_perm"
      def _process_perm_feature(example, prefix):
        for b in range(len(bin_sizes)):
          cnt = example.pop("{}_cnt_{}".format(prefix, b))[0]
          tup = example.pop("{}_tup_{}".format(prefix, b))

          tup = tf.reshape(
              tf.sparse_tensor_to_dense(tup),
              shape=[cnt, 2])

          # tf.float32
          perm = tf.sparse_to_dense(
              sparse_indices=tup,
              output_shape=[tgt_len, bin_sizes[b]],
              sparse_values=1.0,
              default_value=0.0)

          example["{}_perm_{}".format(prefix, b)] = perm

      # whether allow the last batch with a potentially shorter length
      if use_tpu:
        record_spec = {
            "inputs": tf.FixedLenFeature([tgt_len], tf.int64),
            "labels": tf.FixedLenFeature([tgt_len], tf.int64),
        }
      else:
        record_spec = {
            "inputs": tf.VarLenFeature(tf.int64),
            "labels": tf.VarLenFeature(tf.int64),
        }

      # permutation related features
      if bin_sizes and use_tpu:
        # tf.float32
        record_spec["inp_mask"] = tf.FixedLenFeature([tgt_len], tf.float32)
        record_spec["tgt_mask"] = tf.FixedLenFeature([tgt_len], tf.float32)

        record_spec["head_labels"] = tf.FixedLenFeature([tgt_len], tf.int64)

        for b in range(len(bin_sizes)):
          record_spec["inp_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64)
          record_spec["inp_tup_{}".format(b)] = tf.VarLenFeature(tf.int64)
          record_spec["tgt_cnt_{}".format(b)] = tf.FixedLenFeature([1], tf.int64)
          record_spec["tgt_tup_{}".format(b)] = tf.VarLenFeature(tf.int64)

      # retrieve serialized example
      example = tf.parse_single_example(
          serialized=record,
          features=record_spec)

      # transform permutation tuples to permutation matrices
      if bin_sizes and use_tpu:
        _process_perm_feature(example, "inp")
        _process_perm_feature(example, "tgt")

      # cast int64 into int32
      # cast sparse to dense
      for key in list(example.keys()):
        val = example[key]
        if tf.keras.backend.is_sparse(val):
          val = tf.sparse.to_dense(val)
        if val.dtype == tf.int64:
          val = tf.to_int32(val)
        example[key] = val

      if use_tpu:
        return example
      else:
        return example["inputs"], example["labels"]

    file_paths = []
    for file_name in file_names:
      file_path = os.path.join(data_dir, file_name)
      file_paths.append(file_path)

    if split == "train":
      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
      if len(file_paths) > 1:
        dataset = dataset.shuffle(len(file_paths)).repeat()
        dataset = tf.data.TFRecordDataset(dataset)
      elif num_hosts > 1:
        host_id = params["context"].current_host
        # drop the remaining batches
        num_batch_per_host = num_batch // num_hosts

        my_start_sample_id = (host_id * num_batch_per_host * num_core_per_host *
                              per_core_bsz)
        my_sample_num = num_batch_per_host * num_core_per_host * per_core_bsz
        dataset = tf.data.TFRecordDataset(dataset).skip(
            my_start_sample_id).take(my_sample_num)
      else:
        dataset = tf.data.TFRecordDataset(dataset)

      dataset = dataset.map(parser).cache().repeat()
      dataset = dataset.batch(per_core_bsz, drop_remainder=True)
      dataset = dataset.prefetch(num_core_per_host * per_core_bsz)
    else:
      # do not shuffle, repeat or cache in evaluation
      dataset = tf.data.Dataset.from_tensor_slices(file_paths)
      dataset = tf.data.TFRecordDataset(dataset)
      dataset = dataset.map(parser)
      dataset = dataset.batch(per_core_bsz, drop_remainder=True)

    return dataset

  if split == "train" and num_hosts > 1:
    record_info["num_batch"] = num_batch // num_hosts

  return input_fn, record_info


def get_corpus_info(corpus_info_path):
  with open(corpus_info_path, "r") as fp:
    corpus_info = json.load(fp)
  return corpus_info


if __name__ == "__main__":
  FLAGS = flags.FLAGS
  flags.DEFINE_string("data_dir", None,
      help="Location of the data corpus")
  flags.DEFINE_enum("dataset", "wt103",
      ["ptb", "wt2", "wt103", "lm1b", "enwik8", "text8"],
      help="Dataset name.")
  flags.DEFINE_integer("per_host_train_bsz", 60,
      help="train batch size each host")
  flags.DEFINE_integer("per_host_valid_bsz", 60,
      help="valid batch size each host")
  flags.DEFINE_integer("per_host_test_bsz", 0,
      help="If > 0, enter test mode and process test set only."
      " Otherwise, process train and dev sets only.")
  flags.DEFINE_integer("tgt_len", 70,
      help="number of tokens to predict")
  flags.DEFINE_integer("max_batch", -1,
      help="run in debug mode")
  flags.DEFINE_integer("num_core_per_host", 8,
      help="8 for TPU v2.")
  flags.DEFINE_bool("debug", default=False,
      help="Process only the first batch without shuffle for lm1b.")
  flags.DEFINE_integer("num_procs", 1,
      help="number of processes")
  flags.DEFINE_integer("num_passes", 10,
      help="number of passes when use_tpu=True")
  flags.DEFINE_integer("num_shuffle", 4,
      help="number of shuffles for lm1b")
  flags.DEFINE_bool("use_tpu", True,
      help="use tpu")

  tf.app.run(main)
@@ -1,65 +0,0 @@
import os
import tensorflow as tf


def assign_to_gpu(gpu=0, ps_dev="/device:CPU:0"):
  def _assign(op):
    node_def = op if isinstance(op, tf.NodeDef) else op.node_def
    if node_def.op == "Variable":
      return ps_dev
    else:
      return "/gpu:%d" % gpu
  return _assign


def average_grads_and_vars(tower_grads_and_vars):
  def average_dense(grad_and_vars):
    if len(grad_and_vars) == 1:
      return grad_and_vars[0][0]

    grad = grad_and_vars[0][0]
    for g, _ in grad_and_vars[1:]:
      grad += g
    return grad / len(grad_and_vars)

  def average_sparse(grad_and_vars):
    if len(grad_and_vars) == 1:
      return grad_and_vars[0][0]

    indices = []
    values = []
    for g, _ in grad_and_vars:
      indices += [g.indices]
      values += [g.values]
    indices = tf.concat(indices, 0)
    values = tf.concat(values, 0) / len(grad_and_vars)
    return tf.IndexedSlices(values, indices, grad_and_vars[0][0].dense_shape)

  average_grads_and_vars = []
  for grad_and_vars in zip(*tower_grads_and_vars):
    if grad_and_vars[0][0] is None:
      grad = None
    elif isinstance(grad_and_vars[0][0], tf.IndexedSlices):
      grad = average_sparse(grad_and_vars)
    else:
      grad = average_dense(grad_and_vars)
    # Keep in mind that the Variables are redundant because they are shared
    # across towers. So we will just return the first tower's pointer to
    # the Variable.
    v = grad_and_vars[0][1]
    grad_and_var = (grad, v)
    average_grads_and_vars.append(grad_and_var)
  return average_grads_and_vars


def load_from_checkpoint(saver, logdir):
  sess = tf.get_default_session()
  ckpt = tf.train.get_checkpoint_state(logdir)
  if ckpt and ckpt.model_checkpoint_path:
    if os.path.isabs(ckpt.model_checkpoint_path):
      # Restores from checkpoint with absolute path.
      saver.restore(sess, ckpt.model_checkpoint_path)
    else:
      # Restores from checkpoint with relative path.
      saver.restore(sess, os.path.join(logdir, ckpt.model_checkpoint_path))
    return True
  return False
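The tower averaging above is just a per-variable mean of gradients; a toy numpy sketch with hypothetical values:

```python
import numpy as np

# Gradients for one shared variable, one entry per tower (hypothetical values).
tower_grads = [np.array([0.2, -0.4]), np.array([0.6, 0.0])]

avg_grad = sum(tower_grads) / len(tower_grads)
print(avg_grad)  # -> [ 0.4 -0.2], applied once to the shared variable
```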
@ -1,546 +0,0 @@
|
|||
import tensorflow as tf
|
||||
|
||||
|
||||
def positional_embedding(pos_seq, inv_freq, bsz=None):
|
||||
sinusoid_inp = tf.einsum('i,j->ij', pos_seq, inv_freq)
|
||||
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
|
||||
if bsz is not None:
|
||||
return tf.tile(pos_emb[:, None, :], [1, bsz, 1])
|
||||
else:
|
||||
return pos_emb[:, None, :]
|
||||
|
||||
|
||||
def positionwise_FF(inp, d_model, d_inner, dropout, kernel_initializer,
|
||||
scope='ff', is_training=True):
|
||||
output = inp
|
||||
with tf.variable_scope(scope):
|
||||
output = tf.layers.dense(inp, d_inner, activation=tf.nn.relu,
|
||||
kernel_initializer=kernel_initializer,
|
||||
name='layer_1')
|
||||
output = tf.layers.dropout(output, dropout, training=is_training,
|
||||
name='drop_1')
|
||||
output = tf.layers.dense(output, d_model,
|
||||
kernel_initializer=kernel_initializer,
|
||||
name='layer_2')
|
||||
output = tf.layers.dropout(output, dropout, training=is_training,
|
||||
name='drop_2')
|
||||
output = tf.contrib.layers.layer_norm(output + inp, begin_norm_axis=-1)
|
||||
return output
|
||||
|
||||
|
||||
def rel_shift(x):
|
||||
x_size = tf.shape(x)
|
||||
|
||||
x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
|
||||
x = tf.reshape(x, [x_size[1] + 1, x_size[0], x_size[2], x_size[3]])
|
||||
x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
|
||||
x = tf.reshape(x, x_size)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def rel_multihead_attn(w, r, r_w_bias, r_r_bias, attn_mask, mems, d_model,
|
||||
n_head, d_head, dropout, dropatt, is_training,
|
||||
kernel_initializer, scope='rel_attn'):
|
||||
scale = 1 / (d_head ** 0.5)
|
||||
with tf.variable_scope(scope):
|
||||
qlen = tf.shape(w)[0]
|
||||
rlen = tf.shape(r)[0]
|
||||
bsz = tf.shape(w)[1]
|
||||
|
||||
cat = tf.concat([mems, w],
|
||||
0) if mems is not None and mems.shape.ndims > 1 else w
|
||||
w_heads = tf.layers.dense(cat, 3 * n_head * d_head, use_bias=False,
|
||||
kernel_initializer=kernel_initializer, name='qkv')
|
||||
r_head_k = tf.layers.dense(r, n_head * d_head, use_bias=False,
|
||||
kernel_initializer=kernel_initializer, name='r')
|
||||
|
||||
w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, -1)
|
||||
w_head_q = w_head_q[-qlen:]
|
||||
|
||||
klen = tf.shape(w_head_k)[0]
|
||||
|
||||
w_head_q = tf.reshape(w_head_q, [qlen, bsz, n_head, d_head])
|
||||
w_head_k = tf.reshape(w_head_k, [klen, bsz, n_head, d_head])
|
||||
w_head_v = tf.reshape(w_head_v, [klen, bsz, n_head, d_head])
|
||||
|
||||
r_head_k = tf.reshape(r_head_k, [rlen, n_head, d_head])
|
||||
|
||||
rw_head_q = w_head_q + r_w_bias
|
||||
rr_head_q = w_head_q + r_r_bias
|
||||
|
||||
AC = tf.einsum('ibnd,jbnd->ijbn', rw_head_q, w_head_k)
|
||||
BD = tf.einsum('ibnd,jnd->ijbn', rr_head_q, r_head_k)
|
||||
BD = rel_shift(BD)
|
||||
|
||||
attn_score = (AC + BD) * scale
|
||||
attn_mask_t = attn_mask[:, :, None, None]
|
||||
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
|
||||
|
||||
attn_prob = tf.nn.softmax(attn_score, 1)
|
||||
attn_prob = tf.layers.dropout(attn_prob, dropatt, training=is_training)
|
||||
|
||||
attn_vec = tf.einsum('ijbn,jbnd->ibnd', attn_prob, w_head_v)
|
||||
size_t = tf.shape(attn_vec)
|
||||
attn_vec = tf.reshape(attn_vec, [size_t[0], size_t[1], n_head * d_head])
|
||||
|
||||
attn_out = tf.layers.dense(attn_vec, d_model, use_bias=False,
|
||||
kernel_initializer=kernel_initializer, name='o')
|
||||
attn_out = tf.layers.dropout(attn_out, dropout, training=is_training)
|
||||
|
||||
output = tf.contrib.layers.layer_norm(attn_out + w, begin_norm_axis=-1)
|
||||
return output
|
||||
|
||||
|
||||
def embedding_lookup(lookup_table, x, use_tpu=True):
|
||||
if use_tpu:
|
||||
n_token = tf.shape(lookup_table)[0]
|
||||
one_hot_idx = tf.one_hot(x, n_token)
|
||||
if one_hot_idx.shape.ndims == 2:
|
||||
return tf.einsum('nd,in->id', lookup_table, one_hot_idx)
|
||||
else:
|
||||
return tf.einsum('nd,ibn->ibd', lookup_table, one_hot_idx)
|
||||
else:
|
||||
return tf.nn.embedding_lookup(lookup_table, x)
|
||||
|
||||
|
||||
def mask_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
|
||||
proj_initializer, div_val=1,
|
||||
proj_same_dim=True,
|
||||
scope='adaptive_embed', **kwargs):
|
||||
emb_scale = d_proj ** 0.5
|
||||
with tf.variable_scope(scope):
|
||||
if div_val == 1:
|
||||
lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
|
||||
initializer=initializer)
|
||||
y = embedding_lookup(lookup_table, x, use_tpu=False)
|
||||
if d_proj != d_embed:
|
||||
proj_W = tf.get_variable('proj_W', [d_embed, d_proj],
|
||||
initializer=proj_initializer)
|
||||
y = tf.einsum('ibe,ed->ibd', y, proj_W)
|
||||
else:
|
||||
proj_W = None
|
||||
ret_params = [lookup_table, proj_W]
|
||||
else:
|
||||
tables, projs = [], []
|
||||
cutoff_ends = [0] + cutoffs + [n_token]
|
||||
x_size = tf.shape(x)
|
||||
y = tf.zeros([x_size[0], x_size[1], d_proj])
|
||||
for i in range(len(cutoff_ends) - 1):
|
||||
with tf.variable_scope('cutoff_{}'.format(i)):
|
||||
l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
|
||||
mask = (x >= l_idx) & (x < r_idx)
|
||||
cur_x = tf.boolean_mask(x, mask) - l_idx
|
||||
cur_d_embed = d_embed // (div_val ** i)
|
||||
lookup_table = tf.get_variable('lookup_table',
|
||||
[r_idx - l_idx, cur_d_embed],
|
||||
initializer=initializer)
|
||||
cur_y = embedding_lookup(lookup_table, cur_x, use_tpu=False)
|
||||
if d_proj == cur_d_embed and not proj_same_dim:
|
||||
proj_W = None
|
||||
else:
|
||||
proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj],
|
||||
initializer=proj_initializer)
|
||||
cur_y = tf.einsum('id,de->ie', cur_y, proj_W)
|
||||
mask_idx = tf.to_int64(tf.where(mask))
|
||||
y += tf.scatter_nd(mask_idx, cur_y, tf.to_int64(tf.shape(y)))
|
||||
tables.append(lookup_table)
|
||||
projs.append(proj_W)
|
||||
ret_params = [tables, projs]
|
||||
|
||||
y *= emb_scale
|
||||
return y, ret_params
|
||||
|
||||
|
||||
def mul_adaptive_embedding_lookup(x, n_token, d_embed, d_proj, cutoffs, initializer,
|
||||
proj_initializer, div_val=1, perms=None,
|
||||
proj_same_dim=True,
|
||||
scope='adaptive_embed'):
|
||||
"""
|
||||
perms: If None, first compute W = W1 x W2 (projection for each bin),
|
||||
and then compute X x W (embedding lookup). If not None,
|
||||
use bin-based embedding lookup with max_bin_size defined by
|
||||
the shape of perms.
|
||||
"""
|
||||
emb_scale = d_proj ** 0.5
|
||||
with tf.variable_scope(scope):
|
||||
if div_val == 1:
|
||||
lookup_table = tf.get_variable('lookup_table', [n_token, d_embed],
|
||||
initializer=initializer)
|
||||
y = embedding_lookup(lookup_table, x)
|
||||
if d_proj != d_embed:
|
||||
proj_W = tf.get_variable('proj_W', [d_embed, d_proj],
|
||||
initializer=proj_initializer)
|
||||
y = tf.einsum('ibe,ed->ibd', y, proj_W)
|
||||
else:
|
||||
proj_W = None
|
||||
ret_params = [lookup_table, proj_W]
|
||||
else:
|
||||
tables, projs = [], []
|
||||
cutoff_ends = [0] + cutoffs + [n_token]
|
||||
x_size = tf.shape(x)
|
||||
if perms is None:
|
||||
cat_lookup = []
|
||||
else:
|
||||
cat_lookup = tf.zeros([x_size[0], x_size[1], d_proj])
|
||||
for i in range(len(cutoff_ends) - 1):
|
||||
with tf.variable_scope('cutoff_{}'.format(i)):
|
||||
l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
|
||||
cur_d_embed = d_embed // (div_val ** i)
|
||||
lookup_table = tf.get_variable('lookup_table',
|
||||
[r_idx - l_idx, cur_d_embed],
|
||||
initializer=initializer)
|
||||
if cur_d_embed == d_proj and not proj_same_dim:
|
||||
proj_W = None
|
||||
else:
|
||||
proj_W = tf.get_variable('proj_W', [cur_d_embed, d_proj],
|
||||
initializer=proj_initializer)
|
||||
if perms is None:
|
||||
cat_lookup.append(tf.einsum('ie,ed->id', lookup_table, proj_W))
|
||||
else:
|
||||
# speed up the computation of the first bin
|
||||
# also save some meory
|
||||
if i == 0:
|
||||
cur_y = embedding_lookup(lookup_table, tf.minimum(x, r_idx - 1))
|
||||
if proj_W is not None:
|
||||
cur_y = tf.einsum('ibe,ed->ibd', cur_y, proj_W)
|
||||
cur_y *= perms[i][:, :, None]
|
||||
cat_lookup += cur_y
|
||||
else:
|
||||
cur_x = tf.einsum('ib,ibk->k', tf.to_float(x - l_idx), perms[i])
|
||||
cur_x = tf.to_int32(cur_x)
|
||||
cur_y = embedding_lookup(lookup_table, cur_x)
|
||||
if proj_W is not None:
|
||||
cur_y = tf.einsum('ke,ed->kd', cur_y, proj_W)
|
||||
cat_lookup += tf.einsum('kd,ibk->ibd', cur_y, perms[i])
|
||||
tables.append(lookup_table)
|
||||
projs.append(proj_W)
|
||||
if perms is None:
|
||||
cat_lookup = tf.concat(cat_lookup, 0)
|
||||
y = embedding_lookup(cat_lookup, x)
|
||||
else:
|
||||
y = cat_lookup
|
||||
ret_params = [tables, projs]
|
||||
|
||||
y *= emb_scale
|
||||
return y, ret_params
|
||||
|
||||
|
||||
def mask_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
|
||||
params, tie_projs,
|
||||
initializer=None, proj_initializer=None,
|
||||
div_val=1, scope='adaptive_softmax',
|
||||
proj_same_dim=True,
|
||||
return_mean=True, **kwargs):
|
||||
def _logit(x, W, b, proj):
|
||||
y = x
|
||||
if proj is not None:
|
||||
y = tf.einsum('ibd,ed->ibe', y, proj)
|
||||
return tf.einsum('ibd,nd->ibn', y, W) + b
|
||||
|
||||
params_W, params_projs = params[0], params[1]
|
||||
|
||||
def _gather_logprob(logprob, target):
|
||||
lp_size = tf.shape(logprob)
|
||||
r = tf.range(lp_size[0])
|
||||
idx = tf.stack([r, target], 1)
|
||||
return tf.gather_nd(logprob, idx)
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
if len(cutoffs) == 0:
|
||||
softmax_b = tf.get_variable('bias', [n_token],
|
||||
initializer=tf.zeros_initializer())
|
||||
output = _logit(hidden, params_W, softmax_b, params_projs)
|
||||
nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
|
||||
logits=output)
|
||||
else:
|
||||
cutoff_ends = [0] + cutoffs + [n_token]
|
||||
nll = tf.zeros_like(target, dtype=tf.float32)
|
||||
for i in range(len(cutoff_ends) - 1):
|
||||
with tf.variable_scope('cutoff_{}'.format(i)):
|
||||
l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
|
||||
mask = (target >= l_idx) & (target < r_idx)
|
||||
mask_idx = tf.where(mask)
|
||||
cur_target = tf.boolean_mask(target, mask) - l_idx
|
||||
cur_d_embed = d_embed // (div_val ** i)
|
||||
|
||||
if div_val == 1:
|
||||
cur_W = params_W[l_idx: r_idx]
|
||||
else:
|
||||
cur_W = params_W[i]
|
||||
cur_b = tf.get_variable('b', [r_idx - l_idx],
|
||||
initializer=tf.zeros_initializer())
|
||||
if tie_projs[i]:
|
||||
if div_val == 1:
|
||||
cur_proj = params_projs
|
||||
else:
|
||||
cur_proj = params_projs[i]
|
||||
else:
|
||||
if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
|
||||
cur_proj = None
|
||||
else:
|
||||
cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj],
|
||||
initializer=proj_initializer)
|
||||
if i == 0:
|
||||
cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed],
|
||||
initializer=tf.zeros_initializer())
|
||||
cluster_b = tf.get_variable('cluster_b', [len(cutoffs)],
|
||||
initializer=tf.zeros_initializer())
|
||||
cur_W = tf.concat([cur_W, cluster_W], 0)
|
||||
cur_b = tf.concat([cur_b, cluster_b], 0)
|
||||
|
||||
head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
|
||||
head_logprob = tf.nn.log_softmax(head_logit)
|
||||
cur_head_logprob = tf.boolean_mask(head_logprob, mask)
|
||||
cur_logprob = _gather_logprob(cur_head_logprob, cur_target)
|
||||
else:
|
||||
cur_head_logprob = tf.boolean_mask(head_logprob, mask)
|
||||
cur_hidden = tf.boolean_mask(hidden, mask)
|
||||
tail_logit = tf.squeeze(_logit(
|
||||
cur_hidden[None], cur_W, cur_b, cur_proj), 0)
|
||||
tail_logprob = tf.nn.log_softmax(tail_logit)
|
||||
cur_logprob = (cur_head_logprob[:, cutoff_ends[1] + i - 1] +
|
||||
_gather_logprob(tail_logprob, cur_target))
|
||||
nll += tf.scatter_nd(mask_idx, -cur_logprob,
|
||||
tf.to_int64(tf.shape(nll)))
|
||||
if return_mean:
|
||||
nll = tf.reduce_mean(nll)
|
||||
return nll
|
||||
|
||||
|
||||
def mul_adaptive_logsoftmax(hidden, target, n_token, d_embed, d_proj, cutoffs,
|
||||
params, tie_projs,
|
||||
initializer=None, proj_initializer=None,
|
||||
div_val=1, perms=None, proj_same_dim=True,
|
||||
scope='adaptive_softmax',
|
||||
**kwargs):
|
||||
def _logit(x, W, b, proj):
|
||||
y = x
|
||||
if x.shape.ndims == 3:
|
||||
if proj is not None:
|
||||
y = tf.einsum('ibd,ed->ibe', y, proj)
|
||||
return tf.einsum('ibd,nd->ibn', y, W) + b
|
||||
else:
|
||||
if proj is not None:
|
||||
y = tf.einsum('id,ed->ie', y, proj)
|
||||
return tf.einsum('id,nd->in', y, W) + b
|
||||
|
||||
params_W, params_projs = params[0], params[1]
|
||||
|
||||
with tf.variable_scope(scope):
|
||||
if len(cutoffs) == 0:
|
||||
softmax_b = tf.get_variable('bias', [n_token],
|
||||
initializer=tf.zeros_initializer())
|
||||
output = _logit(hidden, params_W, softmax_b, params_projs)
|
||||
nll = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target,
|
||||
logits=output)
|
||||
nll = tf.reduce_mean(nll)
|
||||
else:
|
||||
total_loss, total_cnt = 0, 0
|
||||
cutoff_ends = [0] + cutoffs + [n_token]
|
||||
for i in range(len(cutoff_ends) - 1):
|
||||
with tf.variable_scope('cutoff_{}'.format(i)):
|
||||
l_idx, r_idx = cutoff_ends[i], cutoff_ends[i + 1]
|
||||
|
||||
cur_d_embed = d_embed // (div_val ** i)
|
||||
|
||||
if div_val == 1:
|
||||
cur_W = params_W[l_idx: r_idx]
|
||||
else:
|
||||
cur_W = params_W[i]
|
||||
cur_b = tf.get_variable('b', [r_idx - l_idx],
|
||||
initializer=tf.zeros_initializer())
|
||||
if tie_projs[i]:
|
||||
if div_val == 1:
|
||||
cur_proj = params_projs
|
||||
else:
|
||||
cur_proj = params_projs[i]
|
||||
else:
|
||||
if (div_val == 1 or not proj_same_dim) and d_proj == cur_d_embed:
|
||||
cur_proj = None
|
||||
else:
|
||||
cur_proj = tf.get_variable('proj', [cur_d_embed, d_proj],
|
||||
initializer=proj_initializer)
|
||||
|
||||
if i == 0:
|
||||
cluster_W = tf.get_variable('cluster_W', [len(cutoffs), d_embed],
|
||||
initializer=tf.zeros_initializer())
|
||||
cluster_b = tf.get_variable('cluster_b', [len(cutoffs)],
|
||||
initializer=tf.zeros_initializer())
|
||||
cur_W = tf.concat([cur_W, cluster_W], 0)
|
||||
cur_b = tf.concat([cur_b, cluster_b], 0)
|
||||
|
||||
head_logit = _logit(hidden, cur_W, cur_b, cur_proj)
|
||||
|
||||
head_target = kwargs.get("head_target")
|
||||
head_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=head_target,
|
||||
logits=head_logit)
|
||||
|
||||
masked_loss = head_nll * perms[i]
|
||||
total_loss += tf.reduce_sum(masked_loss)
|
||||
total_cnt += tf.reduce_sum(perms[i])
|
||||
|
||||
# head_logprob = tf.nn.log_softmax(head_logit)
|
||||
|
||||
# final_logprob = head_logprob * perms[i][:, :, None]
|
||||
# final_target = tf.one_hot(target, tf.shape(head_logprob)[2])
|
||||
# total_loss -= tf.einsum('ibn,ibn->', final_logprob, final_target)
|
||||
# total_cnt += tf.reduce_sum(perms[i])
|
||||
else:
|
||||
cur_head_nll = tf.einsum('ib,ibk->k', head_nll, perms[i])
|
||||
|
||||
cur_hidden = tf.einsum('ibd,ibk->kd', hidden, perms[i])
|
||||
tail_logit = _logit(cur_hidden, cur_W, cur_b, cur_proj)
|
||||
|
||||
tail_target = tf.einsum('ib,ibk->k', tf.to_float(target - l_idx),
|
||||
perms[i])
|
||||
tail_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
labels=tf.to_int32(tail_target),
|
||||
logits=tail_logit)
|
||||
|
||||
sum_nll = cur_head_nll + tail_nll
|
||||
mask = tf.reduce_sum(perms[i], [0, 1])
|
||||
|
||||
masked_loss = sum_nll * mask
|
||||
total_loss += tf.reduce_sum(masked_loss)
|
||||
total_cnt += tf.reduce_sum(mask)
|
||||
|
||||
nll = total_loss / total_cnt
|
||||
|
||||
return nll
|
||||
|
||||
|
||||
def _create_mask(qlen, mlen, same_length=False):
|
||||
attn_mask = tf.ones([qlen, qlen])
|
||||
mask_u = tf.matrix_band_part(attn_mask, 0, -1)
|
||||
mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
|
||||
attn_mask_pad = tf.zeros([qlen, mlen])
|
||||
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
|
||||
if same_length:
|
||||
mask_l = tf.matrix_band_part(attn_mask, -1, 0)
|
||||
ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], 1)
|
||||
return ret
|
||||
|
||||
def _cache_mem(curr_out, prev_mem, mem_len=None):
|
||||
if mem_len is None or prev_mem is None:
|
||||
new_mem = curr_out
|
||||
elif mem_len == 0:
|
||||
return prev_mem
|
||||
else:
|
||||
new_mem = tf.concat([prev_mem, curr_out], 0)[- mem_len:]
|
||||
|
||||
return tf.stop_gradient(new_mem)
|
||||
|
||||
|
||||
def transformer(dec_inp, target, mems, n_token, n_layer, d_model, d_embed,
                n_head, d_head, d_inner, dropout, dropatt,
                initializer, is_training, proj_initializer=None,
                mem_len=None, cutoffs=[], div_val=1, tie_projs=[],
                same_length=False, clamp_len=-1, use_tpu=True,
                input_perms=None, target_perms=None, head_target=None,
                untie_r=False, proj_same_dim=True,
                scope='transformer'):
  """
  cutoffs: a list of python int. Cutoffs for adaptive softmax.
  tie_projs: a list of python bools. Whether to tie the projections.
  use_tpu: if True, use one_hot in embedding lookup and bin-based implementation
    of adaptive softmax.
  perms: a list of tensors. Each tensor should be of size [len, bsz, bin_size].
    Only used in the adaptive setting.
  """
  new_mems = []
  with tf.variable_scope(scope):
    if untie_r:
      r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                 initializer=initializer)
      r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                 initializer=initializer)
    else:
      r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                 initializer=initializer)
      r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                 initializer=initializer)

    qlen = tf.shape(dec_inp)[0]
    mlen = tf.shape(mems[0])[0] if mems is not None else 0
    klen = mlen + qlen

    if proj_initializer is None:
      proj_initializer = initializer
    lookup_fn = (mul_adaptive_embedding_lookup if use_tpu else
                 mask_adaptive_embedding_lookup)
    embeddings, shared_params = lookup_fn(
        x=dec_inp,
        n_token=n_token,
        d_embed=d_embed,
        d_proj=d_model,
        cutoffs=cutoffs,
        initializer=initializer,
        proj_initializer=proj_initializer,
        div_val=div_val,
        perms=input_perms,
        proj_same_dim=proj_same_dim)

    attn_mask = _create_mask(qlen, mlen, same_length)

    pos_seq = tf.range(klen - 1, -1, -1.0)
    if clamp_len > 0:
      pos_seq = tf.minimum(pos_seq, clamp_len)
    inv_freq = 1 / (10000 ** (tf.range(0, d_model, 2.0) / d_model))
    pos_emb = positional_embedding(pos_seq, inv_freq)

    output = tf.layers.dropout(embeddings, dropout, training=is_training)
    pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

    if mems is None:
      mems = [None] * n_layer

    for i in range(n_layer):
      # cache new mems
      new_mems.append(_cache_mem(output, mems[i], mem_len))

      with tf.variable_scope('layer_{}'.format(i)):
        output = rel_multihead_attn(
            w=output,
            r=pos_emb,
            r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
            r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
            attn_mask=attn_mask,
            mems=mems[i],
            d_model=d_model,
            n_head=n_head,
            d_head=d_head,
            dropout=dropout,
            dropatt=dropatt,
            is_training=is_training,
            kernel_initializer=initializer)
        output = positionwise_FF(
            inp=output,
            d_model=d_model,
            d_inner=d_inner,
            dropout=dropout,
            kernel_initializer=initializer,
            is_training=is_training)

    output = tf.layers.dropout(output, dropout, training=is_training)

    logsoftmax_fn = (mul_adaptive_logsoftmax if use_tpu else
                     mask_adaptive_logsoftmax)
    loss = logsoftmax_fn(
        hidden=output,
        target=target,
        n_token=n_token,
        d_embed=d_embed,
        d_proj=d_model,
        cutoffs=cutoffs,
        params=shared_params,
        tie_projs=tie_projs,
        initializer=initializer,
        proj_initializer=proj_initializer,
        div_val=div_val,
        perms=target_perms,
        head_target=head_target,
        proj_same_dim=proj_same_dim)
  return loss, new_mems
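`positional_embedding` itself is defined earlier in model.py and is not part of this hunk. A minimal NumPy sketch of the sinusoid table it plausibly builds from the `pos_seq` and `inv_freq` constructed above (the sin/cos concatenation order is an assumption here, not something this diff shows):

```python
import numpy as np

def positional_embedding(pos_seq, inv_freq):
    # Outer product: one row per relative position, one column per frequency.
    sinusoid = np.outer(pos_seq, inv_freq)            # [klen, d_model // 2]
    # Interleaving vs. concatenating sin and cos is a design choice; concat assumed.
    return np.concatenate([np.sin(sinusoid), np.cos(sinusoid)], axis=-1)

d_model = 512
klen = 8                                              # mlen + qlen in the model
pos_seq = np.arange(klen - 1, -1, -1.0)               # klen-1 ... 0, as above
inv_freq = 1 / (10000 ** (np.arange(0, d_model, 2.0) / d_model))
print(positional_embedding(pos_seq, inv_freq).shape)  # (8, 512)
```

Note that memory and current tokens share one table, indexed by relative distance from `klen - 1` down to `0`; `clamp_len` only caps that distance, it does not shorten the attention window.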
@ -1,102 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=../data/enwik8/

# Model
N_LAYER=12
D_MODEL=512
D_EMBED=512
N_HEAD=8
D_HEAD=64
D_INNER=2048

# Training
TGT_LEN=512
MEM_LEN=512

BSZ=24
NUM_CORE=4

# Testing
TEST_TGT_LEN=80
TEST_MEM_LEN=2100
TEST_CLAMP_LEN=820

TEST_BSZ=10
TEST_NUM_CORE=1

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=enwik8 \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${BSZ} \
        --per_host_valid_bsz=${BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=enwik8 \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-enwik8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.1 \
        --dropatt=0.0 \
        --learning_rate=0.00025 \
        --warmup_steps=0 \
        --train_steps=400000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${BSZ} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=200 \
        --save_steps=4000 \
        --do_train=True \
        --do_eval=False \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-enwik8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.0 \
        --dropatt=0.0 \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --clamp_len=${TEST_CLAMP_LEN} \
        --same_length=True \
        --eval_batch_size=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --do_train=False \
        --do_eval=True \
        --eval_split=test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
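The eval branch shrinks the target length (512 to 80) but stretches the recurrence memory (512 to 2100) and clamps relative positions at 820; the model was only ever trained with 512-token memory, so the longer eval memory is pure cache reuse. A hedged back-of-envelope on these settings:

```python
# Back-of-envelope for the enwik8 eval settings above.
TGT_LEN, MEM_LEN = 512, 512            # training
TEST_TGT_LEN, TEST_MEM_LEN = 80, 2100  # evaluation
print(TEST_MEM_LEN / MEM_LEN)          # ~4.1x longer memory at eval time
print(TEST_MEM_LEN // TEST_TGT_LEN)    # cached states span ~26 prior segments
```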
@ -1,122 +0,0 @@
#!/bin/bash

# Path
LOCAL_DIR=../data/enwik8/
GSDATA=
GSEXP=

# TPU setting
NUM_HOST=2
NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16

TEST_NUM_HOST=1
TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16

# Model
N_LAYER=24
D_MODEL=1024
D_EMBED=1024
N_HEAD=8
D_HEAD=128
D_INNER=3072

# Training
TGT_LEN=768
MEM_LEN=768
TRAIN_BSZ=64
VALID_BSZ=64

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=3800
TEST_CLAMP_LEN=1000
TEST_BSZ=16

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=enwik8 \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${TRAIN_BSZ} \
        --per_host_valid_bsz=${VALID_BSZ} \
        --num_core_per_host=${NUM_CORE} \
        --num_passes=10 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/

    SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/

elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=enwik8 \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --num_passes=1 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/enwik8-tfrecords/

elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --data_dir=${GSDATA}/enwik8-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/enwik8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.15 \
        --dropatt=0.15 \
        --learning_rate=0.00025 \
        --warmup_steps=4000 \
        --train_steps=400000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${TRAIN_BSZ} \
        --use_tpu=True \
        --num_hosts=${NUM_HOST} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=1000 \
        --save_steps=10000 \
        --do_train=True \
        --do_eval=False \
        ${@:2}

elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train.py \
        --data_dir=${GSDATA}/enwik8-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/enwik8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --eval_batch_size=${TEST_BSZ} \
        --num_hosts=${TEST_NUM_HOST} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --use_tpu=True \
        --do_train=False \
        --do_eval_only=True \
        --eval_split=test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@ -1,110 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=../data/one-billion-words/

# Model
DIV_VAL=4
N_LAYER=18
D_MODEL=1024
D_EMBED=1024
N_HEAD=8
D_HEAD=128
D_INNER=4096

# Training
TGT_LEN=256
MEM_LEN=256

BSZ=256
NUM_CORE=4

# Testing
TEST_TGT_LEN=32
TEST_MEM_LEN=128
TEST_CLAMP_LEN=-1

TEST_BSZ=16
TEST_NUM_CORE=1

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=lm1b \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${BSZ} \
        --per_host_valid_bsz=${BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=lm1b \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-lm1b \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=False \
        --proj_same_dim=False \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.1 \
        --dropatt=0.0 \
        --learning_rate=0.00025 \
        --warmup_steps=0 \
        --train_steps=400000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${BSZ} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=200 \
        --save_steps=4000 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-lm1b \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=False \
        --proj_same_dim=False \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.0 \
        --dropatt=0.0 \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --clamp_len=${TEST_CLAMP_LEN} \
        --same_length=True \
        --eval_batch_size=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --do_train=False \
        --do_eval=True \
        --eval_split=test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@ -1,136 +0,0 @@
#!/bin/bash

# Path
LOCAL_DIR=../data/one-billion-words/
GSDATA=
GSEXP=

# TPU setting
NUM_HOST=32
NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16

TEST_NUM_HOST=1
TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16

# Model
DIV_VAL=4
N_LAYER=24
D_MODEL=1280
D_EMBED=1280
N_HEAD=16
D_HEAD=80
D_INNER=8192

# Training
TGT_LEN=32
MEM_LEN=32
TRAIN_BSZ=512
VALID_BSZ=512
TRAIN_BSZ_PER_HOST=$((TRAIN_BSZ / NUM_HOST))
VALID_BSZ_PER_HOST=$((VALID_BSZ / NUM_HOST))

# Testing
TEST_TGT_LEN=32
TEST_MEM_LEN=128
TEST_CLAMP_LEN=-1
TEST_BSZ=8

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=lm1b \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${TRAIN_BSZ_PER_HOST} \
        --per_host_valid_bsz=${VALID_BSZ_PER_HOST} \
        --num_core_per_host=${NUM_CORE} \
        --num_passes=10 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/

    SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/

elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=lm1b \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --num_passes=1 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/lm1b-tfrecords/

elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --data_dir=${GSDATA}/lm1b-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/lm1b \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=False \
        --proj_same_dim=False \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.05 \
        --dropatt=0.05 \
        --init_std=0.005 \
        --learning_rate=0.0001 \
        --warmup_steps=30000 \
        --train_steps=1200000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${TRAIN_BSZ} \
        --num_hosts=${NUM_HOST} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=1000 \
        --save_steps=10000 \
        --use_tpu=True \
        --do_eval=False \
        ${@:2}

elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train.py \
        --data_dir=${GSDATA}/lm1b-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/lm1b \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=False \
        --proj_same_dim=False \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --clamp_len=${TEST_CLAMP_LEN} \
        --same_length=True \
        --eval_batch_size=${TEST_BSZ} \
        --num_hosts=${TEST_NUM_HOST} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --use_tpu=True \
        --do_train=False \
        --do_eval_only=True \
        --eval_split=test \
        ${@:2}

else
    echo 'unknown argument 1'
fi
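The `$((TRAIN_BSZ / NUM_HOST))` arithmetic above silently truncates when the global batch does not divide evenly across hosts. A small sketch of the same sharding rule that fails loudly instead (the function name is illustrative, not from the repo):

```python
def per_host_batch(global_bsz: int, num_hosts: int) -> int:
    # Mirrors TRAIN_BSZ_PER_HOST=$((TRAIN_BSZ / NUM_HOST)) in the script,
    # but asserts divisibility rather than dropping the remainder.
    assert global_bsz % num_hosts == 0, "global batch must divide evenly"
    return global_bsz // num_hosts

print(per_host_batch(512, 32))  # 16 examples per host per step
```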
@ -1,102 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=../data/text8/

# Model
N_LAYER=12
D_MODEL=512
D_EMBED=512
N_HEAD=8
D_HEAD=64
D_INNER=2048

# Training
TGT_LEN=512
MEM_LEN=512

BSZ=24
NUM_CORE=4

# Testing
TEST_TGT_LEN=80
TEST_MEM_LEN=2100
TEST_CLAMP_LEN=820

TEST_BSZ=10
TEST_NUM_CORE=1

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=text8 \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${BSZ} \
        --per_host_valid_bsz=${BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=text8 \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-text8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.1 \
        --dropatt=0.0 \
        --learning_rate=0.00025 \
        --warmup_steps=0 \
        --train_steps=400000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${BSZ} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=200 \
        --save_steps=4000 \
        --do_train=True \
        --do_eval=False \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-text8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.0 \
        --dropatt=0.0 \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --clamp_len=${TEST_CLAMP_LEN} \
        --same_length=True \
        --eval_batch_size=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --do_train=False \
        --do_eval=True \
        --eval_split=test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@ -1,122 +0,0 @@
#!/bin/bash

# Path
LOCAL_DIR=../data/text8/
GSDATA=
GSEXP=

# TPU setting
NUM_HOST=2
NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16

TEST_NUM_HOST=1
TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16

# Model
N_LAYER=24
D_MODEL=1024
D_EMBED=1024
N_HEAD=8
D_HEAD=128
D_INNER=3072

# Training
TGT_LEN=768
MEM_LEN=768
TRAIN_BSZ=64
VALID_BSZ=64

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=3800
TEST_CLAMP_LEN=1000
TEST_BSZ=16

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=text8 \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${TRAIN_BSZ} \
        --per_host_valid_bsz=${VALID_BSZ} \
        --num_core_per_host=${NUM_CORE} \
        --num_passes=10 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/

    SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/

elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=text8 \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --num_passes=1 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/text8-tfrecords/

elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --data_dir=${GSDATA}/text8-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/text8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.15 \
        --dropatt=0.15 \
        --learning_rate=0.00025 \
        --warmup_steps=4000 \
        --train_steps=400000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${TRAIN_BSZ} \
        --use_tpu=True \
        --num_hosts=${NUM_HOST} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=1000 \
        --save_steps=10000 \
        --do_train=True \
        --do_eval=False \
        ${@:2}

elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train.py \
        --data_dir=${GSDATA}/text8-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/text8 \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --eval_batch_size=${TEST_BSZ} \
        --num_hosts=${TEST_NUM_HOST} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --use_tpu=True \
        --do_train=False \
        --do_eval_only=True \
        --eval_split=test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@ -1,108 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=../data/wikitext-103/

# Model
DIV_VAL=1
N_LAYER=16
D_MODEL=410
D_EMBED=410
N_HEAD=10
D_HEAD=41
D_INNER=2100

# Training
TGT_LEN=150
MEM_LEN=150

BSZ=60
NUM_CORE=4

# Testing
TEST_TGT_LEN=64
TEST_MEM_LEN=640
TEST_CLAMP_LEN=400

TEST_BSZ=10
TEST_NUM_CORE=1

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=wt103 \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${BSZ} \
        --per_host_valid_bsz=${BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${DATA_ROOT}/ \
        --dataset=wt103 \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_passes=1 \
        --use_tpu=False \
        ${@:2}
elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-wt103 \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=True \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.1 \
        --dropatt=0.0 \
        --learning_rate=0.00025 \
        --warmup_steps=0 \
        --train_steps=400000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${BSZ} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=200 \
        --save_steps=4000 \
        ${@:2}
elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train_gpu.py \
        --data_dir=${DATA_ROOT}/tfrecords \
        --record_info_dir=${DATA_ROOT}/tfrecords/ \
        --corpus_info_path=${DATA_ROOT}/corpus-info.json \
        --model_dir=EXP-wt103 \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=True \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.0 \
        --dropatt=0.0 \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --clamp_len=${TEST_CLAMP_LEN} \
        --same_length=True \
        --eval_batch_size=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --do_train=False \
        --do_eval=True \
        --eval_split=test \
        ${@:2}
else
    echo 'unknown argument 1'
fi
@ -1,134 +0,0 @@
#!/bin/bash

# Path
LOCAL_DIR=../data/wikitext-103/
GSDATA=
GSEXP=

# TPU setting
NUM_HOST=4
NUM_CORE=16 # TPUv2 -> 8 | TPUv3 -> 16

TEST_NUM_HOST=1
TEST_NUM_CORE=8 # TPUv2 -> 8 | TPUv3 -> 16

# Model
DIV_VAL=4
N_LAYER=18
D_MODEL=1024
D_EMBED=1024
N_HEAD=16
D_HEAD=64
D_INNER=4096

# Training
TGT_LEN=384
MEM_LEN=384
TRAIN_BSZ=128
VALID_BSZ=128

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=1600
TEST_CLAMP_LEN=1000
TEST_BSZ=8

if [[ $1 == 'train_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=wt103 \
        --tgt_len=${TGT_LEN} \
        --per_host_train_bsz=${TRAIN_BSZ} \
        --per_host_valid_bsz=${VALID_BSZ} \
        --num_core_per_host=${NUM_CORE} \
        --num_passes=10 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=train.bsz-${TRAIN_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/

    SRC_PATTERN=valid.bsz-${VALID_BSZ}.tlen-${TGT_LEN}.core-${NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/

elif [[ $1 == 'test_data' ]]; then
    python data_utils.py \
        --data_dir=${LOCAL_DIR}/ \
        --dataset=wt103 \
        --tgt_len=${TEST_TGT_LEN} \
        --per_host_test_bsz=${TEST_BSZ} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --num_passes=1 \
        --use_tpu=True \
        ${@:2}

    SRC_PATTERN=test.bsz-${TEST_BSZ}.tlen-${TEST_TGT_LEN}.core-${TEST_NUM_CORE}*
    gsutil cp ${LOCAL_DIR}/tfrecords/${SRC_PATTERN} ${GSDATA}/wt103-tfrecords/

elif [[ $1 == 'train' ]]; then
    echo 'Run training...'
    python train.py \
        --data_dir=${GSDATA}/wt103-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/wt103 \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=True \
        --proj_same_dim=True \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --dropout=0.2 \
        --dropatt=0.2 \
        --init_std=0.005 \
        --learning_rate=0.00025 \
        --warmup_steps=16000 \
        --train_steps=4000000 \
        --tgt_len=${TGT_LEN} \
        --mem_len=${MEM_LEN} \
        --train_batch_size=${TRAIN_BSZ} \
        --num_hosts=${NUM_HOST} \
        --num_core_per_host=${NUM_CORE} \
        --iterations=1000 \
        --save_steps=10000 \
        --use_tpu=True \
        --do_eval=False \
        ${@:2}

elif [[ $1 == 'eval' ]]; then
    echo 'Run evaluation...'
    python train.py \
        --data_dir=${GSDATA}/wt103-tfrecords \
        --record_info_dir=${LOCAL_DIR}/tfrecords/ \
        --corpus_info_path=${LOCAL_DIR}/corpus-info.json \
        --model_dir=${GSEXP}/wt103 \
        --div_val=${DIV_VAL} \
        --untie_r=True \
        --proj_share_all_but_first=True \
        --proj_same_dim=True \
        --n_layer=${N_LAYER} \
        --d_model=${D_MODEL} \
        --d_embed=${D_EMBED} \
        --n_head=${N_HEAD} \
        --d_head=${D_HEAD} \
        --d_inner=${D_INNER} \
        --tgt_len=${TEST_TGT_LEN} \
        --mem_len=${TEST_MEM_LEN} \
        --clamp_len=${TEST_CLAMP_LEN} \
        --same_length=True \
        --eval_batch_size=${TEST_BSZ} \
        --num_hosts=${TEST_NUM_HOST} \
        --num_core_per_host=${TEST_NUM_CORE} \
        --use_tpu=True \
        --do_train=False \
        --do_eval_only=True \
        --eval_split=test \
        ${@:2}

else
    echo 'unknown argument 1'
fi
@ -1,87 +0,0 @@
#!/bin/bash

URL=http://curtis.ml.cmu.edu/datasets/pretrained_xl

DATA_ROOT=./

function download () {
    fileurl=${1}
    filename=${fileurl##*/}
    if [ ! -f ${filename} ]; then
        echo ">>> Download '${filename}' from '${fileurl}'."
        wget --quiet ${fileurl}
    else
        echo "*** File '${filename}' exists. Skip."
    fi
}

cd $DATA_ROOT
mkdir -p pretrained_xl && cd pretrained_xl

# enwik8
mkdir -p tf_enwik8 && cd tf_enwik8

mkdir -p data && cd data
download ${URL}/tf_enwiki8/data/cache.pkl
download ${URL}/tf_enwiki8/data/corpus-info.json
cd ..

mkdir -p model && cd model
download ${URL}/tf_enwiki8/model/checkpoint
download ${URL}/tf_enwiki8/model/model.ckpt-0.data-00000-of-00001
download ${URL}/tf_enwiki8/model/model.ckpt-0.index
download ${URL}/tf_enwiki8/model/model.ckpt-0.meta
cd ..

cd ..

# text8
mkdir -p tf_text8 && cd tf_text8

mkdir -p data && cd data
download ${URL}/tf_text8/data/cache.pkl
download ${URL}/tf_text8/data/corpus-info.json
cd ..

mkdir -p model && cd model
download ${URL}/tf_text8/model/checkpoint
download ${URL}/tf_text8/model/model.ckpt-0.data-00000-of-00001
download ${URL}/tf_text8/model/model.ckpt-0.index
download ${URL}/tf_text8/model/model.ckpt-0.meta
cd ..

cd ..

# wt103
mkdir -p tf_wt103 && cd tf_wt103

mkdir -p data && cd data
download ${URL}/tf_wt103/data/cache.pkl
download ${URL}/tf_wt103/data/corpus-info.json
cd ..

mkdir -p model && cd model
download ${URL}/tf_wt103/model/checkpoint
download ${URL}/tf_wt103/model/model.ckpt-0.data-00000-of-00001
download ${URL}/tf_wt103/model/model.ckpt-0.index
download ${URL}/tf_wt103/model/model.ckpt-0.meta
cd ..

cd ..

# lm1b
mkdir -p tf_lm1b && cd tf_lm1b

mkdir -p data && cd data
download ${URL}/tf_lm1b/data/cache.pkl
download ${URL}/tf_lm1b/data/corpus-info.json
cd ..

mkdir -p model && cd model
download ${URL}/tf_lm1b/model/checkpoint
download ${URL}/tf_lm1b/model/model.ckpt-1191000.data-00000-of-00001
download ${URL}/tf_lm1b/model/model.ckpt-1191000.index
download ${URL}/tf_lm1b/model/model.ckpt-1191000.meta
cd ..

cd ..
@ -1,58 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_enwik8/model

# Model
N_LAYER=24
D_MODEL=1024
D_EMBED=1024
N_HEAD=8
D_HEAD=128
D_INNER=3072

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=3800
TEST_CLAMP_LEN=1000

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0
TEST_BSZ=16
TEST_NUM_CORE=2

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=enwik8 \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-enwik8 \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
@ -1,63 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_lm1b/model

# Model
DIV_VAL=4
N_LAYER=24
D_MODEL=1280
D_EMBED=1280
N_HEAD=16
D_HEAD=80
D_INNER=8192

# Testing
TEST_TGT_LEN=32
TEST_MEM_LEN=128
TEST_CLAMP_LEN=-1

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-1191000
TEST_BSZ=16
TEST_NUM_CORE=1

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=lm1b \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-lm1b \
    --div_val=${DIV_VAL} \
    --untie_r=True \
    --proj_share_all_but_first=False \
    --proj_same_dim=False \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
@ -1,58 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_text8/model

# Model
N_LAYER=24
D_MODEL=1024
D_EMBED=1024
N_HEAD=8
D_HEAD=128
D_INNER=3072

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=3800
TEST_CLAMP_LEN=1000

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0
TEST_BSZ=16
TEST_NUM_CORE=2

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=text8 \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-text8 \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
@ -1,71 +0,0 @@
#!/bin/bash

# Data
DATA_ROOT=./
DATA_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/data
MODEL_DIR=${DATA_ROOT}/pretrained_xl/tf_wt103/model

# Model
DIV_VAL=4
N_LAYER=18
D_MODEL=1024
D_EMBED=1024
N_HEAD=16
D_HEAD=64
D_INNER=4096

# Training
TGT_LEN=256
MEM_LEN=256

BSZ=16
NUM_CORE=2

# Testing
TEST_TGT_LEN=128
TEST_MEM_LEN=1600
TEST_CLAMP_LEN=1000

TEST_CKPT_PATH=${MODEL_DIR}/model.ckpt-0
TEST_BSZ=16
TEST_NUM_CORE=1

echo 'Preprocess test set...'
python data_utils.py \
    --data_dir=${DATA_DIR}/ \
    --dataset=wt103 \
    --tgt_len=${TEST_TGT_LEN} \
    --per_host_test_bsz=${TEST_BSZ} \
    --num_passes=1 \
    --use_tpu=False

echo 'Run evaluation on test set...'
python train_gpu.py \
    --data_dir=${DATA_DIR}/tfrecords \
    --record_info_dir=${DATA_DIR}/tfrecords/ \
    --corpus_info_path=${DATA_DIR}/corpus-info.json \
    --eval_ckpt_path=${TEST_CKPT_PATH} \
    --model_dir=EXP-wt103 \
    --div_val=${DIV_VAL} \
    --untie_r=True \
    --proj_share_all_but_first=True \
    --n_layer=${N_LAYER} \
    --d_model=${D_MODEL} \
    --d_embed=${D_EMBED} \
    --n_head=${N_HEAD} \
    --d_head=${D_HEAD} \
    --d_inner=${D_INNER} \
    --dropout=0.0 \
    --dropatt=0.0 \
    --tgt_len=${TEST_TGT_LEN} \
    --mem_len=${TEST_MEM_LEN} \
    --clamp_len=${TEST_CLAMP_LEN} \
    --same_length=True \
    --eval_batch_size=${TEST_BSZ} \
    --num_core_per_host=${TEST_NUM_CORE} \
    --do_train=False \
    --do_eval=True \
    --eval_split=test
File diff suppressed because it is too large
@ -1,462 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import time

from absl import flags
import absl.logging as _logging  # pylint: disable=unused-import

from six.moves import xrange  # pylint: disable=redefined-builtin

import tensorflow as tf
from tensorflow.gfile import Exists as exists
import model
import data_utils
import tpu_estimator

import numpy as np
from time import sleep


# TPU parameters
flags.DEFINE_string("master", default=None,
    help="master")
flags.DEFINE_string("tpu", default=None,
    help="The Cloud TPU to use for training. This should be either the name "
         "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 url.")
flags.DEFINE_string("gcp_project", default=None,
    help="Project name for the Cloud TPU-enabled project. If not specified, "
         "we will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_string("tpu_zone", default=None,
    help="GCE zone where the Cloud TPU is located in. If not specified, we "
         "will attempt to automatically detect the GCE project from metadata.")
flags.DEFINE_bool("use_tpu", default=True,
    help="Use TPUs rather than plain CPUs.")
flags.DEFINE_integer("num_hosts", default=1,
    help="number of TPU hosts")
flags.DEFINE_integer("num_core_per_host", default=8,
    help="number of cores per host")

# Experiment (data/checkpoint/directory) parameters
flags.DEFINE_string("data_dir", default="",
    help="Path to tf-records directory.")
flags.DEFINE_string("record_info_dir", default="",
    help="Path to local directory containing filenames.txt.")
flags.DEFINE_string("corpus_info_path", default="",
    help="Path to corpus-info.json file.")
flags.DEFINE_string("model_dir", default=None,
    help="Estimator model_dir.")
flags.DEFINE_bool("do_eval", default=False,
    help="Whether to run eval on the dev set.")
flags.DEFINE_bool("track_mean", default=True,
    help="Trace mean loss during training.")
flags.DEFINE_string("eval_ckpt_path", None,
    help="Checkpoint path for evaluation. "
         "If set, model_dir will be ignored. "
         "If unset, will use the latest ckpt in model_dir.")
flags.DEFINE_string("warm_start_path", None,
    help="Checkpoint path for warm start. "
         "If set, will clear Adam states. "
         "Note that the new model_dir should be different"
         " from warm_start_path.")

# Optimization parameters
flags.DEFINE_float("learning_rate", default=2.5e-4,
    help="Maximum learning rate.")
flags.DEFINE_float("clip", default=0.25,
    help="Gradient clipping value.")
# for cosine decay
flags.DEFINE_float("min_lr_ratio", default=0.01,
    help="Minimum ratio learning rate.")
flags.DEFINE_integer("warmup_steps", default=0,
    help="Number of steps for linear lr warmup.")

# Training parameters
flags.DEFINE_integer("train_batch_size", default=60,
    help="Size of train batch.")
flags.DEFINE_integer("eval_batch_size", default=60,
    help="Size of valid batch.")
flags.DEFINE_integer("train_steps", default=100000,
    help="Total number of training steps.")
flags.DEFINE_integer("iterations", default=500,
    help="Number of iterations per repeat loop.")
flags.DEFINE_integer("save_steps", default=10000,
    help="Number of steps for model checkpointing.")

# Evaluation parameters
flags.DEFINE_integer("max_eval_batch", default=-1,
    help="Set -1 to turn off. Only used in test mode.")
flags.DEFINE_bool("do_eval_only", default=False,
    help="Run evaluation only.")
flags.DEFINE_integer("start_eval_steps", default=10000,
    help="Which checkpoint to start with in `do_eval_only` mode.")
flags.DEFINE_string("eval_split", "valid",
    help="Which data split to evaluate.")

# Model parameters
flags.DEFINE_integer("tgt_len", default=70,
    help="Number of steps to predict")
flags.DEFINE_integer("mem_len", default=70,
    help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False,
    help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1,
    help="Clamp length")

flags.DEFINE_integer("n_layer", default=6,
    help="Number of layers.")
flags.DEFINE_integer("d_model", default=500,
    help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=500,
    help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=10,
    help="Number of attention heads.")
flags.DEFINE_integer("d_head", default=50,
    help="Dimension of each attention head.")
flags.DEFINE_integer("d_inner", default=1000,
    help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1,
    help="Dropout rate.")
flags.DEFINE_float("dropatt", default=0.1,
    help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False,
    help="Untie r_w_bias and r_r_bias")

# Adaptive Softmax / Embedding
flags.DEFINE_bool("tie_weight", default=True,
    help="Tie embedding and softmax weight.")
flags.DEFINE_integer("div_val", default=1,
    help="Divide the embedding size by this val for each bin")
flags.DEFINE_bool("proj_share_all_but_first", default=False,
    help="True to share all but first projs, False not to share.")
flags.DEFINE_bool("proj_same_dim", default=True,
    help="Project the bin with the same dimension.")

# Parameter initialization
flags.DEFINE_enum("init", default="normal",
    enum_values=["normal", "uniform"],
    help="Initialization method.")
flags.DEFINE_float("init_std", default=0.02,
    help="Initialization std when init is normal.")
flags.DEFINE_float("proj_init_std", default=0.01,
    help="Initialization std for embedding projection.")
flags.DEFINE_float("init_range", default=0.1,
    help="Initialization std when init is uniform.")


FLAGS = flags.FLAGS

def metric_fn(loss):
  """Evaluation metric Fn which runs on CPU."""
  perplexity = tf.exp(tf.reduce_mean(loss))
  bpc = tf.reduce_mean(loss) / tf.constant(math.log(2))
  return {
      "perplexity": tf.metrics.mean(perplexity),
      "bpc": tf.metrics.mean(bpc),
  }

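Both metrics are deterministic functions of the same mean cross-entropy: perplexity is `exp(loss)` with the loss in nats, and bits-per-character divides by `ln 2`. A worked example with a hypothetical loss value:

```python
import math

mean_loss = 0.72  # hypothetical mean cross-entropy in nats
perplexity = math.exp(mean_loss)
bpc = mean_loss / math.log(2)
print(f"perplexity={perplexity:.3f}, bpc={bpc:.3f}")
# perplexity=2.054, bpc=1.039
```
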
def get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes):
  def model_fn(features, labels, mode, params):
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    batch_size = params["batch_size"]

    mems = params["cache"]
    inp = tf.transpose(features["inputs"], [1, 0])
    tgt = tf.transpose(features["labels"], [1, 0])

    bin_sizes = train_bin_sizes if is_training else eval_bin_sizes
    if bin_sizes:
      inp_perms = [tf.transpose(features["inp_mask"], [1, 0])]
      tgt_perms = [tf.transpose(features["tgt_mask"], [1, 0])]

      head_tgt = tf.transpose(features["head_labels"], [1, 0])

      for b in range(len(bin_sizes)):
        inp_perm = tf.transpose(features["inp_perm_{}".format(b)], [1, 0, 2])
        tgt_perm = tf.transpose(features["tgt_perm_{}".format(b)], [1, 0, 2])

        inp_perms.append(inp_perm)
        tgt_perms.append(tgt_perm)
    else:
      inp_perms, tgt_perms, head_tgt = None, None, None

    if FLAGS.init == "uniform":
      initializer = tf.initializers.random_uniform(
          minval=-FLAGS.init_range,
          maxval=FLAGS.init_range,
          seed=None)
    elif FLAGS.init == "normal":
      initializer = tf.initializers.random_normal(
          stddev=FLAGS.init_std,
          seed=None)
      proj_initializer = tf.initializers.random_normal(
          stddev=FLAGS.proj_init_std,
          seed=None)

    tie_projs = [False for _ in range(len(cutoffs) + 1)]
    if FLAGS.proj_share_all_but_first:
      for i in range(1, len(tie_projs)):
        tie_projs[i] = True

    tf.logging.info("Vocab size : {}".format(n_token))
    tf.logging.info("Batch size : {}".format(batch_size))

    loss, new_mems = model.transformer(
        dec_inp=inp,
        target=tgt,
        mems=mems,
        n_token=n_token,
        n_layer=FLAGS.n_layer,
        d_model=FLAGS.d_model,
        d_embed=FLAGS.d_embed,
        n_head=FLAGS.n_head,
        d_head=FLAGS.d_head,
        d_inner=FLAGS.d_inner,
        dropout=FLAGS.dropout,
        dropatt=FLAGS.dropatt,
        initializer=initializer,
        is_training=is_training,
        mem_len=FLAGS.mem_len,
        cutoffs=cutoffs,
        div_val=FLAGS.div_val,
        tie_projs=tie_projs,
        input_perms=inp_perms,
        target_perms=tgt_perms,
        head_target=head_tgt,
        same_length=FLAGS.same_length,
        clamp_len=FLAGS.clamp_len,
        use_tpu=FLAGS.use_tpu,
        untie_r=FLAGS.untie_r,
        proj_same_dim=FLAGS.proj_same_dim)

    total_loss = tf.reduce_mean(loss)

    if mode == tf.estimator.ModeKeys.EVAL:
      if FLAGS.use_tpu:
        with tf.colocate_with(total_loss):
          total_loss = tf.contrib.tpu.cross_replica_sum(total_loss) \
              / FLAGS.num_hosts / FLAGS.num_core_per_host
      metric_loss = tf.tile(tf.reshape(total_loss, [1, 1]), [batch_size, 1])
      eval_spec = tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=(metric_fn, [metric_loss]))

      eval_spec.cache = new_mems

      return eval_spec

    # Configuring the optimization step.
    global_step = tf.train.get_global_step()

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
      warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
          * FLAGS.learning_rate
    else:
      warmup_lr = 0.0

    # number of parameters
    num_params = np.sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info("#params: {}".format(num_params))

    # format_str = '{{:<{0}s}}\t{{}}'.format(
    #     max([len(v.name) for v in tf.trainable_variables()]))
    # for v in tf.trainable_variables():
    #   tf.logging.info(format_str.format(v.name, v.get_shape()))

    # decay the learning rate using the cosine schedule
    decay_lr = tf.train.cosine_decay(
        FLAGS.learning_rate,
        global_step=global_step - FLAGS.warmup_steps,
        decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
        alpha=FLAGS.min_lr_ratio)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                             warmup_lr, decay_lr)

    if FLAGS.use_tpu:
      optimizer = tf.contrib.tpu.CrossShardOptimizer(
          tf.train.AdamOptimizer(learning_rate=learning_rate))
      # GradientDescentOptimizer
    else:
      optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)

    grads_and_vars = optimizer.compute_gradients(total_loss)
    gradients, variables = zip(*grads_and_vars)
    clipped, _ = tf.clip_by_global_norm(gradients, FLAGS.clip)
    train_op = optimizer.apply_gradients(
        zip(clipped, variables), global_step=tf.train.get_global_step())

    # Constructing TPUEstimatorSpec with cache.
    train_spec = tf.contrib.tpu.TPUEstimatorSpec(
        mode=mode, loss=total_loss, train_op=train_op)

    if FLAGS.mem_len < FLAGS.tgt_len:
      # Truncate each cached memory to mem_len (the original comprehension
      # mistakenly sliced the whole list instead of each element).
      new_mems = [mem_t[:FLAGS.mem_len] for mem_t in new_mems]
    train_spec.cache = new_mems

    return train_spec

  return model_fn

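The schedule assembled above (linear warmup handed off to `tf.train.cosine_decay`, floored at `min_lr_ratio`) is easier to inspect outside the graph. A framework-free sketch of the same curve, plugged with the enwik8 TPU settings and the default `min_lr_ratio=0.01` (the floor semantics mirror TensorFlow's documented cosine decay; treat this as an approximation of the graph version):

```python
import math

def lr_at(step, base_lr=2.5e-4, warmup=4000, total=400000, min_ratio=0.01):
    if step < warmup:
        return base_lr * step / warmup                  # linear warmup
    # Cosine decay from base_lr down to base_lr * min_ratio.
    progress = (step - warmup) / (total - warmup)
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return base_lr * ((1 - min_ratio) * cosine + min_ratio)

for s in (0, 4000, 200000, 400000):
    print(s, f"{lr_at(s):.2e}")   # 0, peak 2.50e-04, ~1.28e-04, 2.50e-06
```
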
def get_cache_fn(mem_len):

  def cache_fn(batch_size):
    mems = []
    for l in xrange(FLAGS.n_layer):
      if mem_len > 0:
        mems.append(
            tf.zeros([mem_len, batch_size, FLAGS.d_model], dtype=tf.float32))
      else:
        mems.append(tf.zeros([mem_len], dtype=tf.float32))

    return mems

  return cache_fn

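`cache_fn` pre-allocates one zero tensor per layer, and the shape `[mem_len, batch_size, d_model]` is the contract every `new_mems` entry returned by `model.transformer` has to match. A plain-NumPy shape check using the flag defaults (illustrative only, not part of the training graph):

```python
import numpy as np

n_layer, mem_len, batch_size, d_model = 6, 70, 4, 500  # the flag defaults
mems = [np.zeros((mem_len, batch_size, d_model), dtype=np.float32)
        for _ in range(n_layer)]
assert all(m.shape == (mem_len, batch_size, d_model) for m in mems)
print(len(mems), mems[0].shape)  # 6 (70, 4, 500)
```
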
def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
  cutoffs = corpus_info["cutoffs"][1:-1]

  if FLAGS.save_steps == 0:
    FLAGS.save_steps = None

  if not FLAGS.do_eval_only:
    # Get train input function
    train_input_fn, train_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split="train",
        per_host_bsz=FLAGS.train_batch_size // FLAGS.num_hosts,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=FLAGS.num_hosts,
        use_tpu=FLAGS.use_tpu)
    train_bin_sizes = train_record_info["bin_sizes"]
    num_train_batch = train_record_info["num_batch"]

    # Get train cache function
    train_cache_fn = get_cache_fn(FLAGS.mem_len)
  else:
    train_bin_sizes = []
    num_train_batch = None
    train_cache_fn = None

  if FLAGS.do_eval or FLAGS.do_eval_only:
    assert FLAGS.num_hosts == 1
    # Get eval input function
    eval_input_fn, eval_record_info = data_utils.get_input_fn(
        record_info_dir=FLAGS.record_info_dir,
        split=FLAGS.eval_split,
        per_host_bsz=FLAGS.eval_batch_size // FLAGS.num_hosts,
        tgt_len=FLAGS.tgt_len,
        num_core_per_host=FLAGS.num_core_per_host,
        num_hosts=FLAGS.num_hosts,
        use_tpu=FLAGS.use_tpu)
    eval_bin_sizes = eval_record_info["bin_sizes"]
    num_eval_batch = eval_record_info["num_batch"]

    if FLAGS.max_eval_batch > 0:
      num_eval_batch = min(FLAGS.max_eval_batch, num_eval_batch)

    # Get eval cache function
    eval_cache_fn = get_cache_fn(FLAGS.mem_len)
    model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, eval_bin_sizes)
  else:
    eval_cache_fn = None
    model_fn = get_model_fn(n_token, cutoffs, train_bin_sizes, [])

  ##### Create estimator
  # TPU Configuration
  tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
      FLAGS.tpu, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  per_host_input = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      model_dir=FLAGS.model_dir,
      session_config=tf.ConfigProto(
          allow_soft_placement=True, log_device_placement=True),
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations,
          num_shards=FLAGS.num_core_per_host * FLAGS.num_hosts,
          per_host_input_for_training=per_host_input),
      keep_checkpoint_max=100000,  # effectively save all checkpoints
      save_checkpoints_secs=None,
      save_checkpoints_steps=FLAGS.save_steps
  )

  # warm start
  warm_start_from = None
  if FLAGS.warm_start_path is not None:
    warm_start_from = tf.estimator.WarmStartSettings(
        ckpt_to_initialize_from=FLAGS.warm_start_path)

  # TPU Estimator
  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      train_cache_fn=train_cache_fn,
      eval_cache_fn=eval_cache_fn,
      use_tpu=FLAGS.use_tpu,
      config=run_config,
      params={"data_dir": FLAGS.data_dir, "track_mean": FLAGS.track_mean},
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      warm_start_from=warm_start_from)

  if FLAGS.do_eval_only:
    if FLAGS.eval_ckpt_path is not None:
      ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch,
                               checkpoint_path=FLAGS.eval_ckpt_path)
      tf.logging.info("=" * 200)
      log_str = "Eval results | "
      for key, val in ret.items():
        log_str += "{} {} | ".format(key, val)
      tf.logging.info(log_str)
      tf.logging.info("=" * 200)
    else:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.model_dir)
      eval_results = []
      for eval_checkpoint in ckpt_state.all_model_checkpoint_paths:
        if not exists(eval_checkpoint + ".index"): continue
        global_step = int(eval_checkpoint.split("-")[-1])
        if global_step < FLAGS.start_eval_steps or global_step > FLAGS.train_steps:
          continue
        ret = estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch,
                                 checkpoint_path=eval_checkpoint)
        eval_results.append(ret)

      eval_results.sort(key=lambda x: x["perplexity"])

      tf.logging.info("=" * 200)
      log_str = "Best results | "
      for key, val in eval_results[0].items():
        log_str += "{} {} | ".format(key, val)
      tf.logging.info(log_str)
      tf.logging.info("=" * 200)
  else:
    if not FLAGS.do_eval:
      estimator.train(input_fn=train_input_fn, steps=FLAGS.train_steps)
    else:
      for step in range(0, FLAGS.train_steps, num_train_batch):
        train_steps = min(FLAGS.train_steps - step, num_train_batch)
        estimator.train(input_fn=train_input_fn, steps=train_steps)
        estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_batch)


if __name__ == "__main__":
  tf.app.run()
@ -1,475 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import time

from absl import flags
import absl.logging as _logging  # pylint: disable=unused-import

import tensorflow as tf
import model
import data_utils

from gpu_utils import assign_to_gpu, average_grads_and_vars

import numpy as np

# GPU config
flags.DEFINE_integer("num_hosts", default=1,
      help="Number of TPU hosts")
flags.DEFINE_integer("num_core_per_host", default=8,
      help="Number of cores per host")

# Experiment (data/checkpoint/directory) config
flags.DEFINE_string("data_dir", default="",
      help="Path to tf-records directory.")
flags.DEFINE_string("record_info_dir", default="",
      help="Path to local directory containing filenames.txt.")
flags.DEFINE_string("corpus_info_path", default="",
      help="Path to corpus-info.json file.")
flags.DEFINE_string("model_dir", default=None,
      help="Estimator model_dir.")
flags.DEFINE_bool("do_train", default=True,
      help="Whether to run training.")
flags.DEFINE_bool("do_eval", default=False,
      help="Whether to run eval on the dev set.")
flags.DEFINE_string("eval_ckpt_path", None,
      help="Checkpoint path for do_test evaluation. "
      "If set, model_dir will be ignored. "
      "If unset, will use the latest ckpt in model_dir.")
flags.DEFINE_string("warm_start_path", None,
      help="Checkpoint path for warm start. "
      "If set, will clear Adam states. "
      "Note that the new model_dir should be different"
      " from warm_start_path.")

# Optimization config
flags.DEFINE_float("learning_rate", default=2.5e-4,
      help="Maximum learning rate.")
flags.DEFINE_float("clip", default=0.25,
      help="Gradient clipping value.")
# for cosine decay
flags.DEFINE_float("min_lr_ratio", default=0.004,
      help="Minimum ratio learning rate.")
flags.DEFINE_integer("warmup_steps", default=0,
      help="Number of steps for linear lr warmup.")

# Training config
flags.DEFINE_integer("train_batch_size", default=60,
      help="Size of train batch.")
flags.DEFINE_integer("eval_batch_size", default=60,
      help="Size of valid batch.")
flags.DEFINE_integer("train_steps", default=100000,
      help="Total number of training steps.")
flags.DEFINE_integer("iterations", default=500,
      help="Number of iterations per repeat loop.")
flags.DEFINE_integer("save_steps", default=10000,
      help="Number of steps between model checkpoints.")

# Evaluation config
flags.DEFINE_bool("do_test", default=False,
      help="Run on the test set.")
flags.DEFINE_integer("max_eval_batch", default=-1,
      help="Set -1 to turn off. Only used in test mode.")
flags.DEFINE_bool("do_eval_only", default=False,
      help="Run evaluation only.")
flags.DEFINE_integer("start_eval_steps", default=10000,
      help="Which checkpoint to start with in `do_eval_only` mode.")
flags.DEFINE_string("eval_split", "valid",
      help="Which data split to evaluate.")

# Model config
flags.DEFINE_integer("tgt_len", default=70,
      help="Number of steps to predict")
flags.DEFINE_integer("mem_len", default=70,
      help="Number of steps to cache")
flags.DEFINE_bool("same_length", default=False,
      help="Same length attention")
flags.DEFINE_integer("clamp_len", default=-1,
      help="Clamp length")

flags.DEFINE_integer("n_layer", default=6,
      help="Number of layers.")
flags.DEFINE_integer("d_model", default=500,
      help="Dimension of the model.")
flags.DEFINE_integer("d_embed", default=500,
      help="Dimension of the embeddings.")
flags.DEFINE_integer("n_head", default=10,
      help="Number of attention heads.")
flags.DEFINE_integer("d_head", default=50,
      help="Dimension of each attention head.")
flags.DEFINE_integer("d_inner", default=1000,
      help="Dimension of inner hidden size in positionwise feed-forward.")
flags.DEFINE_float("dropout", default=0.1,
      help="Dropout rate.")
flags.DEFINE_float("dropatt", default=0.1,
      help="Attention dropout rate.")
flags.DEFINE_bool("untie_r", default=False,
      help="Untie r_w_bias and r_r_bias")

# Adaptive Softmax / Embedding
flags.DEFINE_bool("tie_weight", default=True,
      help="Tie embedding and softmax weight.")
flags.DEFINE_integer("div_val", default=1,
      help="Divide the embedding size by this val for each bin")
flags.DEFINE_bool("proj_share_all_but_first", default=False,
      help="True to share all but first projs, False not to share.")
flags.DEFINE_bool("proj_same_dim", default=True,
      help="Project the bin with the same dimension.")

# Parameter initialization
flags.DEFINE_enum("init", default="normal",
      enum_values=["normal", "uniform"],
      help="Initialization method.")
flags.DEFINE_float("init_std", default=0.02,
      help="Initialization std when init is normal.")
flags.DEFINE_float("proj_init_std", default=0.01,
      help="Initialization std for embedding projection.")
flags.DEFINE_float("init_range", default=0.1,
      help="Initialization std when init is uniform.")

FLAGS = flags.FLAGS
def get_model_fn(n_token, cutoffs):
  def model_fn(inp, tgt, mems, is_training):
    inp = tf.transpose(inp, [1, 0])
    tgt = tf.transpose(tgt, [1, 0])

    if FLAGS.init == "uniform":
      initializer = tf.initializers.random_uniform(
          minval=-FLAGS.init_range,
          maxval=FLAGS.init_range,
          seed=None)
    elif FLAGS.init == "normal":
      initializer = tf.initializers.random_normal(
          stddev=FLAGS.init_std,
          seed=None)
    # used for the embedding projections under both init modes
    proj_initializer = tf.initializers.random_normal(
        stddev=FLAGS.proj_init_std,
        seed=None)

    tie_projs = [False for _ in range(len(cutoffs) + 1)]
    if FLAGS.proj_share_all_but_first:
      for i in range(1, len(tie_projs)):
        tie_projs[i] = True

    loss, new_mems = model.transformer(
        dec_inp=inp,
        target=tgt,
        mems=mems,
        n_token=n_token,
        n_layer=FLAGS.n_layer,
        d_model=FLAGS.d_model,
        d_embed=FLAGS.d_embed,
        n_head=FLAGS.n_head,
        d_head=FLAGS.d_head,
        d_inner=FLAGS.d_inner,
        dropout=FLAGS.dropout,
        dropatt=FLAGS.dropatt,
        initializer=initializer,
        proj_initializer=proj_initializer,
        is_training=is_training,
        mem_len=FLAGS.mem_len,
        cutoffs=cutoffs,
        div_val=FLAGS.div_val,
        tie_projs=tie_projs,
        input_perms=None,
        target_perms=None,
        head_target=None,
        same_length=FLAGS.same_length,
        clamp_len=FLAGS.clamp_len,
        use_tpu=False,
        untie_r=FLAGS.untie_r,
        proj_same_dim=FLAGS.proj_same_dim)

    # number of parameters
    num_params = sum([np.prod(v.shape) for v in tf.trainable_variables()])
    tf.logging.info("#params: {}".format(num_params))

    # format_str = '{{:<{0}s}}\t{{}}'.format(
    #     max([len(v.name) for v in tf.trainable_variables()]))
    # for v in tf.trainable_variables():
    #   tf.logging.info(format_str.format(v.name, v.get_shape()))

    if is_training:
      all_vars = tf.trainable_variables()
      grads = tf.gradients(loss, all_vars)
      grads_and_vars = list(zip(grads, all_vars))

      return loss, new_mems, grads_and_vars
    else:
      return loss, new_mems

  return model_fn


def single_core_graph(n_token, cutoffs, is_training, inp, tgt, mems):
  model_fn = get_model_fn(
      n_token=n_token,
      cutoffs=cutoffs)

  model_ret = model_fn(
      inp=inp,
      tgt=tgt,
      mems=mems,
      is_training=is_training)

  return model_ret


def train(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  train_input_fn, train_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split="train",
      per_host_bsz=FLAGS.train_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  tf.logging.info("num of batches {}".format(train_record_info["num_batch"]))

  ##### Create computational graph
  train_set = train_input_fn({
      "batch_size": FLAGS.train_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = train_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.train_batch_size // FLAGS.num_core_per_host

  tower_mems, tower_losses, tower_new_mems, tower_grads_and_vars = [], [], [], []

  for i in range(FLAGS.num_core_per_host):
    reuse = True if i > 0 else None
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=reuse):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i, grads_and_vars_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=True,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)
      tower_grads_and_vars.append(grads_and_vars_i)

  ## average losses and gradients across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
    grads_and_vars = average_grads_and_vars(tower_grads_and_vars)
  else:
    loss = tower_losses[0]
    grads_and_vars = tower_grads_and_vars[0]
  grads, all_vars = zip(*grads_and_vars)

  ## clip gradient
  clipped, gnorm = tf.clip_by_global_norm(grads, FLAGS.clip)
  grads_and_vars = list(zip(clipped, all_vars))

  ## configure the optimizer
  global_step = tf.train.get_or_create_global_step()

  # warmup stage: increase the learning rate linearly
  if FLAGS.warmup_steps > 0:
    warmup_lr = tf.to_float(global_step) / tf.to_float(FLAGS.warmup_steps) \
        * FLAGS.learning_rate
  else:
    warmup_lr = 0.0

  # decay stage: decay the learning rate using the cosine schedule
  decay_lr = tf.train.cosine_decay(
      FLAGS.learning_rate,
      global_step=global_step - FLAGS.warmup_steps,
      decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
      alpha=FLAGS.min_lr_ratio)

  # choose warmup or decay
  learning_rate = tf.where(global_step < FLAGS.warmup_steps,
                           warmup_lr, decay_lr)
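  # Net effect: LR rises linearly from 0 to FLAGS.learning_rate over
  # warmup_steps, then cosine-decays to min_lr_ratio * learning_rate
  # by train_steps.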

  # get the train op
  optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  train_op = optimizer.apply_gradients(grads_and_vars, global_step)

  ##### Training loop
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]
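  # Transformer-XL recurrence: these zero-initialized memories are fed in
  # through the `tower_mems` placeholders each step, and the `new_mems`
  # fetched from the session replace them for the next step.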

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.warm_start_path is not None:
      tf.logging.info("warm start from {}".format(FLAGS.warm_start_path))
      saver.restore(sess, FLAGS.warm_start_path)

    fetches = [loss, tower_new_mems, global_step, gnorm, learning_rate, train_op]

    total_loss, prev_step = 0., -1
    while True:
      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, curr_step = fetched[:3]
      total_loss += loss_np

      if curr_step > 0 and curr_step % FLAGS.iterations == 0:
        curr_loss = total_loss / (curr_step - prev_step)
        tf.logging.info("[{}] | gnorm {:.2f} lr {:8.6f} "
                        "| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
                            curr_step, fetched[-3], fetched[-2],
                            curr_loss, math.exp(curr_loss), curr_loss / math.log(2)))
        total_loss, prev_step = 0., curr_step

      if curr_step > 0 and curr_step % FLAGS.save_steps == 0:
        save_path = os.path.join(FLAGS.model_dir, "model.ckpt")
        saver.save(sess, save_path)
        tf.logging.info("Model saved in path: {}".format(save_path))

      if curr_step == FLAGS.train_steps:
        break


def evaluate(n_token, cutoffs, ps_device):
  ##### Get input function and model function
  eval_input_fn, eval_record_info = data_utils.get_input_fn(
      record_info_dir=FLAGS.record_info_dir,
      split=FLAGS.eval_split,
      per_host_bsz=FLAGS.eval_batch_size,
      tgt_len=FLAGS.tgt_len,
      num_core_per_host=FLAGS.num_core_per_host,
      num_hosts=1,
      use_tpu=False)

  num_batch = eval_record_info["num_batch"]
  if FLAGS.max_eval_batch > 0:
    num_batch = FLAGS.max_eval_batch
  tf.logging.info("num of batches {}".format(num_batch))

  ##### Create computational graph
  eval_set = eval_input_fn({
      "batch_size": FLAGS.eval_batch_size,
      "data_dir": FLAGS.data_dir})

  input_feed, label_feed = eval_set.make_one_shot_iterator().get_next()

  inputs = tf.split(input_feed, FLAGS.num_core_per_host, 0)
  labels = tf.split(label_feed, FLAGS.num_core_per_host, 0)

  per_core_bsz = FLAGS.eval_batch_size // FLAGS.num_core_per_host
  tower_mems, tower_losses, tower_new_mems = [], [], []

  for i in range(FLAGS.num_core_per_host):
    with tf.device(assign_to_gpu(i, ps_device)), \
        tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):

      mems_i = [tf.placeholder(tf.float32,
                               [FLAGS.mem_len, per_core_bsz, FLAGS.d_model])
                for _ in range(FLAGS.n_layer)]

      loss_i, new_mems_i = single_core_graph(
          n_token=n_token,
          cutoffs=cutoffs,
          is_training=False,
          inp=inputs[i],
          tgt=labels[i],
          mems=mems_i)

      tower_mems.append(mems_i)
      tower_losses.append(loss_i)
      tower_new_mems.append(new_mems_i)

  ## average losses across towers
  if len(tower_losses) > 1:
    loss = tf.add_n(tower_losses) / len(tower_losses)
  else:
    loss = tower_losses[0]

  ##### Evaluation loop
  tower_mems_np = [
      [np.zeros([FLAGS.mem_len, per_core_bsz, FLAGS.d_model], dtype=np.float32)
       for layer in range(FLAGS.n_layer)]
      for core in range(FLAGS.num_core_per_host)
  ]

  saver = tf.train.Saver()

  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    sess.run(tf.global_variables_initializer())

    if FLAGS.eval_ckpt_path is None:
      eval_ckpt_path = tf.train.latest_checkpoint(FLAGS.model_dir)
    else:
      eval_ckpt_path = FLAGS.eval_ckpt_path
    tf.logging.info("Evaluate {}".format(eval_ckpt_path))
    saver.restore(sess, eval_ckpt_path)

    fetches = [loss, tower_new_mems, tf.size(label_feed)]

    format_str = " >> processing batch {{:{0}d}}/{{:{0}d}} ..".format(
        len(str(num_batch)))

    total_loss, total_cnt = 0, 0
    for step in range(num_batch):
      if step % (num_batch // 10) == 0:
        tf.logging.info(format_str.format(step, num_batch))

      feed_dict = {}
      for i in range(FLAGS.num_core_per_host):
        for m, m_np in zip(tower_mems[i], tower_mems_np[i]):
          feed_dict[m] = m_np

      fetched = sess.run(fetches, feed_dict=feed_dict)

      loss_np, tower_mems_np, cnt_np = fetched[:3]
      # weight each batch's mean loss by its token count
      total_loss += loss_np * cnt_np
      total_cnt += cnt_np

    avg_loss = total_loss / total_cnt
    tf.logging.info("| loss {:.2f} | pplx {:>7.2f}, bpc {:>7.4f}".format(
        avg_loss, math.exp(avg_loss), avg_loss / math.log(2)))


def main(unused_argv):
  del unused_argv  # Unused

  tf.logging.set_verbosity(tf.logging.INFO)

  # Get corpus info
  corpus_info = data_utils.get_corpus_info(FLAGS.corpus_info_path)
  n_token = corpus_info["vocab_size"]
  cutoffs = corpus_info["cutoffs"][1:-1]
  tf.logging.info("n_token {}".format(n_token))

  if FLAGS.do_train:
    train(n_token, cutoffs, "/gpu:0")
  if FLAGS.do_eval:
    evaluate(n_token, cutoffs, "/gpu:0")


if __name__ == "__main__":
  tf.app.run()
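For intuition, here is a minimal NumPy sketch (not part of the deleted script) of the warmup-plus-cosine schedule configured above; `lr`, `warmup`, `total`, and `min_ratio` are stand-ins for `learning_rate`, `warmup_steps`, `train_steps`, and `min_lr_ratio`:

```python
import numpy as np

def lr_at(step, lr=2.5e-4, warmup=0, total=100000, min_ratio=0.004):
    """Learning rate at `step` under linear warmup + cosine decay."""
    if warmup > 0 and step < warmup:
        return step / warmup * lr                       # linear warmup
    progress = (step - warmup) / (total - warmup)       # in [0, 1]
    cosine = 0.5 * (1 + np.cos(np.pi * progress))       # 1 -> 0
    return lr * ((1 - min_ratio) * cosine + min_ratio)  # floor at min_ratio * lr
```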
@ -1,170 +0,0 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import Counter, OrderedDict

import numpy as np

import tensorflow as tf

from tensorflow.gfile import Open as open
from tensorflow.gfile import Exists as exists


class Vocab(object):
  def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
               delimiter=None, vocab_file=None):
    self.counter = Counter()
    self.special = special
    self.min_freq = min_freq
    self.max_size = max_size
    self.lower_case = lower_case
    self.delimiter = delimiter
    self.vocab_file = vocab_file

  def tokenize(self, line, add_eos=False, add_double_eos=False):
    line = line.strip()
    # convert to lower case
    if self.lower_case:
      line = line.lower()

    # the empty delimiter '' is falsy, so compare explicitly;
    # it selects character-level splitting
    if self.delimiter == '':
      symbols = line
    else:
      symbols = line.split(self.delimiter)

    if add_double_eos:  # lm1b
      return ['<S>'] + symbols + ['<S>']
    elif add_eos:
      return symbols + ['<eos>']
    else:
      return symbols

  def count_file(self, path, verbose=False, add_eos=False):
    if verbose: print('counting file {} ...'.format(path))
    assert exists(path)

    sents = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('  line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos)
        self.counter.update(symbols)
        sents.append(symbols)

    return sents

  def count_sents(self, sents, verbose=False):
    """
    sents : a list of sentences, each a list of tokenized symbols
    """
    if verbose: print('counting {} sents ...'.format(len(sents)))
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('  line {}'.format(idx))
      self.counter.update(symbols)

  def _build_from_file(self, vocab_file):
    self.idx2sym = []
    self.sym2idx = OrderedDict()

    with open(vocab_file, 'r') as f:
      for line in f:
        symb = line.strip().split()[0]
        self.add_symbol(symb)
    self.unk_idx = self.sym2idx['<UNK>']

  def build_vocab(self):
    if self.vocab_file:
      print('building vocab from {}'.format(self.vocab_file))
      self._build_from_file(self.vocab_file)
      print('final vocab size {}'.format(len(self)))
    else:
      print('building vocab with min_freq={}, max_size={}'.format(
          self.min_freq, self.max_size))
      self.idx2sym = []
      self.sym2idx = OrderedDict()

      for sym in self.special:
        self.add_special(sym)

      for sym, cnt in self.counter.most_common(self.max_size):
        if cnt < self.min_freq: break
        self.add_symbol(sym)

      print('final vocab size {} from {} unique tokens'.format(
          len(self), len(self.counter)))

  def encode_file(self, path, ordered=False, verbose=False, add_eos=True,
                  add_double_eos=False):
    if verbose: print('encoding file {} ...'.format(path))
    assert exists(path)
    encoded = []
    with open(path, 'r') as f:
      for idx, line in enumerate(f):
        if verbose and idx > 0 and idx % 500000 == 0:
          print('  line {}'.format(idx))
        symbols = self.tokenize(line, add_eos=add_eos,
                                add_double_eos=add_double_eos)
        encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded

  def encode_sents(self, sents, ordered=False, verbose=False):
    if verbose: print('encoding {} sents ...'.format(len(sents)))
    encoded = []
    for idx, symbols in enumerate(sents):
      if verbose and idx > 0 and idx % 500000 == 0:
        print('  line {}'.format(idx))
      encoded.append(self.convert_to_nparray(symbols))

    if ordered:
      encoded = np.concatenate(encoded)

    return encoded

  def add_special(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1
      # e.g. '<eos>' also becomes accessible as `self.eos_idx`
      setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym])

  def add_symbol(self, sym):
    if sym not in self.sym2idx:
      self.idx2sym.append(sym)
      self.sym2idx[sym] = len(self.idx2sym) - 1

  def get_sym(self, idx):
    assert 0 <= idx < len(self), 'Index {} out of range'.format(idx)
    return self.idx2sym[idx]

  def get_idx(self, sym):
    if sym in self.sym2idx:
      return self.sym2idx[sym]
    else:
      # unseen symbols map to <UNK>, which must exist in the vocab
      assert hasattr(self, 'unk_idx')
      return self.sym2idx.get(sym, self.unk_idx)

  def get_symbols(self, indices):
    return [self.get_sym(idx) for idx in indices]

  def get_indices(self, symbols):
    return [self.get_idx(sym) for sym in symbols]

  def convert_to_nparray(self, symbols):
    nparray = np.array(self.get_indices(symbols), dtype=np.int64)
    return nparray

  def convert_to_sent(self, indices, exclude=None):
    if exclude is None:
      return ' '.join([self.get_sym(idx) for idx in indices])
    else:
      return ' '.join([self.get_sym(idx) for idx in indices
                       if idx not in exclude])

  def __len__(self):
    return len(self.idx2sym)
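A minimal usage sketch of `Vocab` (the file paths below are placeholders, not files from this repository). Setting `delimiter=''` gives the character-level vocabulary typically used for character or byte compression benchmarks:

```python
# Hypothetical corpus files; substitute your own.
vocab = Vocab(special=['<eos>'], lower_case=False, delimiter='')

vocab.count_file('train.txt')      # accumulate character counts
vocab.build_vocab()                # freeze the symbol <-> index tables

# One flat int64 array per corpus, suitable for batching into tgt_len chunks.
train_ids = vocab.encode_file('train.txt', ordered=True, add_eos=False)
valid_ids = vocab.encode_file('valid.txt', ordered=True, add_eos=False)

print(len(vocab), train_ids[:20])
```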