This repository has been archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../integer_discrete_flows/utils/load_data.py
2025-11-07 12:54:36 +01:00

264 lines
7.4 KiB
Python

from __future__ import print_function
import numbers
import torch
import torch.utils.data as data_utils
import pickle
from scipy.io import loadmat
import numpy as np
import os
from torch.utils.data import Dataset
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as vf
from torch.utils.data import ConcatDataset
from PIL import Image
import os
import os.path
from os.path import join
import sys
import tarfile
class ToTensorNoNorm():
    """Convert an HWC image (PIL image or ndarray-like) to a CHW torch tensor.

    Unlike torchvision's ToTensor, values are NOT rescaled to [0, 1]; the
    original integer pixel values are preserved.
    """

    def __call__(self, X_i):
        # copy=False shares the underlying buffer when the input allows it.
        as_array = np.array(X_i, copy=False)
        return torch.from_numpy(as_array).permute(2, 0, 1)  # HWC -> CHW
class PadToMultiple(object):
    """Pad a PIL image on the right/bottom so width and height become multiples of `multiple`.

    Args:
        multiple (int): divisor that both output dimensions must satisfy.
        fill: fill value forwarded to torchvision's pad (used in 'constant' mode).
        padding_mode (str): one of 'constant', 'edge', 'reflect', 'symmetric'.
    """

    def __init__(self, multiple, fill=0, padding_mode='constant'):
        assert isinstance(multiple, numbers.Number)
        assert isinstance(fill, (numbers.Number, str, tuple))
        assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
        self.multiple = multiple
        self.fill = fill
        self.padding_mode = padding_mode

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be padded.
        Returns:
            PIL Image: Padded image.
        """
        w, h = img.size
        m = self.multiple
        # Round each dimension up to the next multiple of m.
        nw = (w // m + int((w % m) != 0)) * m
        nh = (h // m + int((h % m) != 0)) * m
        padw = nw - w
        padh = nh - h
        # Pad only right/bottom so existing pixel coordinates are unchanged.
        out = vf.pad(img, (0, 0, padw, padh), self.fill, self.padding_mode)
        return out

    def __repr__(self):
        # BUG FIX: original read `self.mulitple` (typo) and raised AttributeError.
        return self.__class__.__name__ + '(multiple={0}, fill={1}, padding_mode={2})'.\
            format(self.multiple, self.fill, self.padding_mode)
class CustomTensorDataset(Dataset):
    """Dataset wrapping tensors, with an optional transform applied to the inputs.

    Each sample is retrieved by indexing tensors along the first dimension.
    The first tensor holds HWC image data; __getitem__ returns it as CHW.

    Arguments:
        *tensors (Tensor): tensors that have the same size of the first dimension.
        transform (callable, optional): applied to the raw X[index] before the
            HWC -> CHW conversion.
    """

    def __init__(self, *tensors, transform=None):
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transform = transform

    def __getitem__(self, index):
        # FIX: removed an unused `from PIL import Image` that was re-executed
        # on every sample fetch.
        # NOTE(review): exactly two tensors (X, y) are expected here even
        # though __init__ accepts *tensors — confirm callers never pass more.
        X, y = self.tensors
        X_i, y_i = X[index], y[index]
        if self.transform:
            X_i = self.transform(X_i)
        # copy=False shares memory with the transform output when possible.
        X_i = torch.from_numpy(np.array(X_i, copy=False))
        X_i = X_i.permute(2, 0, 1)  # HWC -> CHW
        return X_i, y_i

    def __len__(self):
        return self.tensors[0].size(0)
def load_cifar10(args, **kwargs):
    """Build CIFAR-10 train/val/test DataLoaders.

    The last 10000 training images are held out for validation. Augmentation
    is chosen by args.data_augmentation_level (0: none, 1: horizontal flip,
    2: flip + pad/translate/crop). Mutates args: sets input_size, input_type
    and dynamic_binarization.
    """
    # Dataset metadata consumed elsewhere in the project.
    args.input_size = [3, 32, 32]
    args.input_type = 'continuous'
    args.dynamic_binarization = False

    from keras.datasets import cifar10
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()
    # NHWC -> NCHW
    x_train = x_train.transpose(0, 3, 1, 2)
    x_test = x_test.transpose(0, 3, 1, 2)

    import math
    if args.data_augmentation_level == 2:
        extra = [
            transforms.RandomHorizontalFlip(),
            # ceil(32 * 0.05) = 2 pixels of edge padding before translation.
            transforms.Pad(int(math.ceil(32 * 0.05)), padding_mode='edge'),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.CenterCrop(32),
        ]
    elif args.data_augmentation_level == 1:
        extra = [transforms.RandomHorizontalFlip()]
    else:
        extra = []
    data_transform = transforms.Compose([transforms.ToPILImage()] + extra)

    # Hold out the tail of the training split for validation.
    n_val = 10000
    x_val, y_val = x_train[-n_val:], y_train[-n_val:]
    x_train, y_train = x_train[:-n_val], y_train[:-n_val]

    train_set = CustomTensorDataset(
        torch.from_numpy(x_train), torch.from_numpy(y_train),
        transform=data_transform)
    train_loader = data_utils.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=True, **kwargs)

    val_set = data_utils.TensorDataset(
        torch.from_numpy(x_val), torch.from_numpy(y_val))
    val_loader = data_utils.DataLoader(
        val_set, batch_size=args.batch_size, shuffle=False, **kwargs)

    test_set = data_utils.TensorDataset(
        torch.from_numpy(x_test), torch.from_numpy(y_test))
    test_loader = data_utils.DataLoader(
        test_set, batch_size=args.batch_size, shuffle=False, **kwargs)

    return train_loader, val_loader, test_loader, args
def extract_tar(tarpath):
    """Extract `foo.tar` into `foo/images0/`, `foo/images1/`, ... (50000 files per folder).

    Directory members are skipped. Each file is extracted under its basename
    only, so nested archive paths are flattened into the chunk folder.
    Returns the extraction root; if it already exists, the archive is assumed
    to be extracted and is reused as-is.
    """
    assert tarpath.endswith('.tar')
    startdir = tarpath[:-4] + '/'
    if os.path.exists(startdir):
        return startdir
    print('Extracting', tarpath)
    with tarfile.open(name=tarpath) as tar:
        t = 0
        done = False
        while not done:
            path = join(startdir, 'images{}'.format(t))
            os.makedirs(path, exist_ok=True)
            print(path)
            for i in range(50000):
                member = tar.next()
                if member is None:
                    done = True
                    break
                # Skip directory members.
                while member.isdir():
                    member = tar.next()
                    if member is None:
                        done = True
                        break
                if done:
                    # BUG FIX: the archive ended on a directory member; the
                    # original fell through and dereferenced member=None.
                    break
                # Flatten to the basename so the file lands directly in `path`.
                member.name = member.name.split('/')[-1]
                tar.extract(member, path=path)
            t += 1
    return startdir
def load_imagenet(resolution, args, **kwargs):
    """Build DataLoaders for downsampled ImageNet (32x32 or 64x64).

    Expects ../imagenet{res}/train_{res}x{res}.tar and valid_{res}x{res}.tar,
    which are extracted on first use. 20000 training images are held out for
    validation; the archive's validation split serves as the test set.
    Mutates args: sets input_size.
    """
    assert resolution == 32 or resolution == 64
    args.input_size = [3, resolution, resolution]

    train_tar = '../imagenet{res}/train_{res}x{res}.tar'.format(res=resolution)
    test_tar = '../imagenet{res}/valid_{res}x{res}.tar'.format(res=resolution)
    train_root = extract_tar(train_tar)
    test_root = extract_tar(test_tar)

    data_transform = transforms.Compose([ToTensorNoNorm()])

    print('Starting loading ImageNet')
    full_train = torchvision.datasets.ImageFolder(
        train_root, transform=data_transform)
    print('Number of data images', len(full_train))

    # NOTE(review): the split uses an unseeded RNG, so the 20000 held-out
    # validation images differ between runs — confirm this is intended.
    val_idcs = np.random.choice(len(full_train), size=20000, replace=False)
    train_idcs = np.setdiff1d(np.arange(len(full_train)), val_idcs)

    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.dataset.Subset(full_train, train_idcs),
        batch_size=args.batch_size,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        torch.utils.data.dataset.Subset(full_train, val_idcs),
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    test_set = torchvision.datasets.ImageFolder(
        test_root, transform=data_transform)
    print('Number of val images:', len(test_set))
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.batch_size,
        shuffle=False,
        **kwargs)

    return train_loader, val_loader, test_loader, args
def load_dataset(args, **kwargs):
    """Dispatch to the dataset loader named by args.dataset.

    Supported names: 'cifar10', 'imagenet32', 'imagenet64'.

    Returns:
        (train_loader, val_loader, test_loader, args)
    Raises:
        Exception: if args.dataset is not a supported name.
    """
    name = args.dataset
    if name == 'cifar10':
        return load_cifar10(args, **kwargs)
    if name == 'imagenet32':
        return load_imagenet(32, args, **kwargs)
    if name == 'imagenet64':
        return load_imagenet(64, args, **kwargs)
    raise Exception('Wrong name of the dataset!')
if __name__ == '__main__':
    # Minimal stand-in for the argparse namespace the loaders expect.
    class Args():
        def __init__(self):
            self.batch_size = 128

    # BUG FIX: the original called undefined `load_imagenet32` (NameError);
    # the actual entry point is load_imagenet(resolution, args).
    train_loader, val_loader, test_loader, args = load_imagenet(32, Args())