""" Collection of flow strategies """ from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F from models.utils import Base UNIT_TESTING = False class Conv2dReLU(Base): def __init__( self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0, bias=True): super().__init__() self.nn = nn.Conv2d(n_inputs, n_outputs, kernel_size, padding=padding) def forward(self, x): h = self.nn(x) y = F.relu(h) return y class ResidualBlock(Base): def __init__(self, n_channels, kernel, Conv2dAct): super().__init__() self.nn = torch.nn.Sequential( Conv2dAct(n_channels, n_channels, kernel, padding=1), torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1), ) def forward(self, x): h = self.nn(x) h = F.relu(h + x) return h class DenseLayer(Base): def __init__(self, args, n_inputs, growth, Conv2dAct): super().__init__() conv1x1 = Conv2dAct( n_inputs, n_inputs, kernel_size=1, stride=1, padding=0, bias=True) self.nn = torch.nn.Sequential( conv1x1, Conv2dAct( n_inputs, growth, kernel_size=3, stride=1, padding=1, bias=True), ) def forward(self, x): h = self.nn(x) h = torch.cat([x, h], dim=1) return h class DenseBlock(Base): def __init__( self, args, n_inputs, n_outputs, kernel, Conv2dAct): super().__init__() depth = args.densenet_depth future_growth = n_outputs - n_inputs layers = [] for d in range(depth): growth = future_growth // (depth - d) layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct)) n_inputs += growth future_growth -= growth self.nn = torch.nn.Sequential(*layers) def forward(self, x): return self.nn(x) class Identity(Base): def __init__(self): super.__init__() def forward(self, x): return x class NN(Base): def __init__( self, args, c_in, c_out, height, width, nn_type, kernel=3): super().__init__() Conv2dAct = Conv2dReLU n_channels = args.n_channels if nn_type == 'shallow': if args.network1x1 == 'standard': conv1x1 = Conv2dAct( n_channels, n_channels, kernel_size=1, stride=1, padding=0, bias=False) layers = [ Conv2dAct(c_in, n_channels, kernel, padding=1), conv1x1] layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)] elif nn_type == 'resnet': layers = [ Conv2dAct(c_in, n_channels, kernel, padding=1), ResidualBlock(n_channels, kernel, Conv2dAct), ResidualBlock(n_channels, kernel, Conv2dAct)] layers += [ torch.nn.Conv2d(n_channels, c_out, kernel, padding=1) ] elif nn_type == 'densenet': layers = [ DenseBlock( args=args, n_inputs=c_in, n_outputs=n_channels + c_in, kernel=kernel, Conv2dAct=Conv2dAct)] layers += [ torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1) ] else: raise ValueError self.nn = torch.nn.Sequential(*layers) # Set parameters of last conv-layer to zero. if not UNIT_TESTING: self.nn[-1].weight.data.zero_() self.nn[-1].bias.data.zero_() def forward(self, x): return self.nn(x)