from __future__ import print_function
|
|
|
|
import numbers
|
|
|
|
import torch
|
|
import torch.utils.data as data_utils
|
|
import pickle
|
|
from scipy.io import loadmat
|
|
|
|
import numpy as np
|
|
|
|
import os
|
|
from torch.utils.data import Dataset
|
|
import torchvision
|
|
from torchvision import transforms
|
|
from torchvision.transforms import functional as vf
|
|
from torch.utils.data import ConcatDataset
|
|
|
|
from PIL import Image
|
|
|
|
import os
|
|
import os.path
|
|
from os.path import join
|
|
import sys
|
|
import tarfile
|
|
|
|
|
|
class ToTensorNoNorm():
    """Convert an HWC image-like object to a CHW torch tensor.

    Unlike ``transforms.ToTensor`` this keeps the raw pixel values
    (no division by 255).
    """

    def __call__(self, X_i):
        # Share memory with the source where possible, then HWC -> CHW.
        arr = np.asarray(X_i)
        return torch.from_numpy(arr).permute(2, 0, 1)
|
|
|
|
|
|
class PadToMultiple(object):
    """Pad a PIL image on the right/bottom so width and height become
    multiples of ``multiple``.

    Args:
        multiple (int): divisor the padded width/height must reach.
        fill: fill value used when ``padding_mode`` is 'constant'.
        padding_mode (str): 'constant', 'edge', 'reflect' or 'symmetric'.
    """

    def __init__(self, multiple, fill=0, padding_mode='constant'):
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']

        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.

        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple
        # Round each side up to the next multiple of m.
        nw = (w // m + int((w % m) != 0)) * m
        nh = (h // m + int((h % m) != 0)) * m
        padw = nw - w
        padh = nh - h

        # Pad only right/bottom so the original content keeps its origin.
        out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
        return out

    def __repr__(self):
        # Fixed: the original read `self.mulitple` (typo), so repr() raised
        # AttributeError instead of producing a description.
        return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
            format(self.multiple, self.fill, self.padding_mode)
|
|
|
|
|
|
class CustomTensorDataset(Dataset):
    """Dataset wrapping tensors, with an optional per-sample transform.

    Each sample is retrieved by indexing tensors along the first dimension.
    Exactly two tensors are consumed in ``__getitem__``: images in
    (N, H, W, C) layout and labels. Images are returned in CHW layout.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first
            dimension.
        transform (callable, optional): applied to each image before the
            HWC -> CHW conversion (e.g. a torchvision pipeline starting
            with ``ToPILImage``).
    """

    def __init__(self, *tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        # Note: removed an unused function-local `from PIL import Image`.
        X, y = self.tensors
        X_i, y_i = X[index], y[index]

        if self.transform:
            X_i = self.transform(X_i)

        # The transform may return a PIL image; convert back to a tensor
        # without value normalization, then HWC -> CHW.
        X_i = torch.from_numpy(np.asarray(X_i))
        X_i = X_i.permute(2, 0, 1)

        return X_i, y_i

    def __len__(self):
        return self.tensors[0].size(0)
|
|
|
|
|
|
def load_cifar10(args, **kwargs):
    """Build CIFAR-10 train/val/test DataLoaders.

    The last 10000 training images become the validation split. Train-time
    augmentation is selected by ``args.data_augmentation_level`` (0/1/2).
    Returns (train_loader, val_loader, test_loader, args).
    """
    import math

    # set args
    args.input_size = [3, 32, 32]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    # Lazy import: keras is only needed for this dataset.
    from keras.datasets import cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # NHWC -> NCHW
    x_train = x_train.transpose(0, 3, 1, 2)
    x_test = x_test.transpose(0, 3, 1, 2)

    # Augmentation pipeline by level; every pipeline starts from a PIL image.
    if args.data_augmentation_level == 2:
        steps = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
            transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.CenterCrop(32),
        ]
    elif args.data_augmentation_level == 1:
        steps = [
            transforms.ToPILImage(),
            transforms.RandomHorizontalFlip(),
        ]
    else:
        steps = [transforms.ToPILImage()]
    data_transform = transforms.Compose(steps)

    # Hold out the final 10000 training images for validation.
    x_val, y_val = x_train[-10000:], y_train[-10000:]
    x_train, y_train = x_train[:-10000], y_train[:-10000]

    train = CustomTensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train),
        transform=data_transform)
    train_loader = data_utils.DataLoader(
        train, batch_size=args.batch_size, shuffle=True, **kwargs)

    validation = data_utils.TensorDataset(
        torch.from_numpy(x_val), torch.from_numpy(y_val))
    val_loader = data_utils.DataLoader(
        validation, batch_size=args.batch_size, shuffle=False, **kwargs)

    test = data_utils.TensorDataset(
        torch.from_numpy(x_test), torch.from_numpy(y_test))
    test_loader = data_utils.DataLoader(
        test, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args
|
|
|
|
|
|
def extract_tar(tarpath):
    """Extract a '.tar' archive into numbered flat image directories.

    ``xxx.tar`` is unpacked into ``xxx/images0/``, ``xxx/images1/``, ...
    with at most 50000 files per directory. Directory entries inside the
    archive are skipped and member paths are flattened to their basename.
    If ``xxx/`` already exists, extraction is skipped entirely.

    Args:
        tarpath (str): path to a file ending in '.tar'.

    Returns:
        str: the extraction root directory ('xxx/').
    """
    assert tarpath.endswith('.tar')

    startdir = tarpath[:-4] + '/'

    if os.path.exists(startdir):
        # Assume a previous run already extracted everything.
        return startdir

    print('Extracting', tarpath)

    with tarfile.open(name=tarpath) as tar:
        t = 0
        done = False
        while not done:
            path = join(startdir, 'images{}'.format(t))
            os.makedirs(path, exist_ok=True)

            print(path)

            for i in range(50000):
                member = tar.next()

                # Skip directory entries. Fixed: the original dereferenced
                # `member.name` even when this skip loop exhausted the
                # archive (member is None), crashing with AttributeError
                # whenever the tar ended in a directory entry.
                while member is not None and member.isdir():
                    member = tar.next()

                if member is None:
                    done = True
                    break

                # Flatten: strip any leading directories from the name.
                member.name = member.name.split('/')[-1]

                tar.extract(member, path=path)

            t += 1

    return startdir
|
|
|
|
|
|
def load_imagenet(resolution, args, **kwargs):
    """Build small-ImageNet (32x32 or 64x64) train/val/test DataLoaders.

    Expects '../imagenet{res}/train_{res}x{res}.tar' and
    '../imagenet{res}/valid_{res}x{res}.tar'; both are extracted on first
    use. 20000 training images are held out for validation, and the
    archive's validation split serves as the test set.
    """
    assert resolution == 32 or resolution == 64

    args.input_size = [3, resolution, resolution]

    train_tar = '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution)
    valid_tar = '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution)

    trainpath = extract_tar(train_tar)
    valpath = extract_tar(valid_tar)

    # Raw uint8 tensors, no 0-1 normalization.
    data_transform = transforms.Compose([ToTensorNoNorm()])

    print('Starting loading ImageNet')

    imagenet_data = torchvision.datasets.ImageFolder(
        trainpath, transform=data_transform)

    print('Number of data images', len(imagenet_data))

    # NOTE(review): this split is re-drawn on every call (no fixed seed),
    # so the train/val partition differs between runs — confirm intended.
    n_images = len(imagenet_data)
    val_idcs = np.random.choice(n_images, size=20000, replace=False)
    train_idcs = np.setdiff1d(np.arange(n_images), val_idcs)

    train_dataset = torch.utils.data.dataset.Subset(imagenet_data, train_idcs)
    val_dataset = torch.utils.data.dataset.Subset(imagenet_data, val_idcs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)

    test_dataset = torchvision.datasets.ImageFolder(
        valpath, transform=data_transform)

    print('Number of val images:', len(test_dataset))

    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args
|
|
|
|
|
|
def load_dataset(args, **kwargs):
    """Dispatch to the loader matching ``args.dataset``.

    Supported names: 'cifar10', 'imagenet32', 'imagenet64'. Returns
    (train_loader, val_loader, test_loader, args); raises ``Exception``
    for any other dataset name.
    """
    name = args.dataset

    if name == 'cifar10':
        loaders = load_cifar10(args, **kwargs)
    elif name == 'imagenet32':
        loaders = load_imagenet(32, args, **kwargs)
    elif name == 'imagenet64':
        loaders = load_imagenet(64, args, **kwargs)
    else:
        raise Exception('Wrong name of the dataset!')

    train_loader, val_loader, test_loader, args = loaders
    return train_loader, val_loader, test_loader, args
|
|
|
|
|
|
if __name__ == '__main__':
    # Minimal smoke test: load ImageNet32 with only a batch size configured.
    class Args():
        def __init__(self):
            self.batch_size = 128

    # Fixed: the original called undefined `load_imagenet32` (NameError);
    # the actual entry point is `load_imagenet(resolution, args)`.
    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())
|