# integer_discrete_flows/models/networks.py
"""
Collection of flow strategies
"""
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.utils import Base

# When False (the default), the final convolution of each NN is
# zero-initialised so that the network outputs zeros at the start of
# training; see NN.__init__ below.
UNIT_TESTING = False


class Conv2dReLU(Base):
    """Convolution (3x3 by default) followed by a ReLU."""

    def __init__(
            self, n_inputs, n_outputs, kernel_size=3, stride=1, padding=0,
            bias=True):
        super().__init__()
        self.nn = nn.Conv2d(
            n_inputs, n_outputs, kernel_size, stride=stride,
            padding=padding, bias=bias)

    def forward(self, x):
        h = self.nn(x)
        y = F.relu(h)
        return y


class ResidualBlock(Base):
    """Conv-ReLU-Conv block with an additive skip connection."""

    def __init__(self, n_channels, kernel, Conv2dAct):
        super().__init__()
        # padding=1 keeps the spatial size fixed for the default 3x3 kernel.
        self.nn = torch.nn.Sequential(
            Conv2dAct(n_channels, n_channels, kernel, padding=1),
            torch.nn.Conv2d(n_channels, n_channels, kernel, padding=1),
        )

    def forward(self, x):
        h = self.nn(x)
        # The ReLU is applied after the residual addition: y = relu(F(x) + x).
        h = F.relu(h + x)
        return h


class DenseLayer(Base):
    """1x1 bottleneck followed by a 3x3 convolution producing `growth`
    new feature maps, which are concatenated onto the input."""

    def __init__(self, args, n_inputs, growth, Conv2dAct):
        super().__init__()
        conv1x1 = Conv2dAct(
            n_inputs, n_inputs, kernel_size=1, stride=1,
            padding=0, bias=True)

        self.nn = torch.nn.Sequential(
            conv1x1,
            Conv2dAct(
                n_inputs, growth, kernel_size=3, stride=1,
                padding=1, bias=True),
        )

    def forward(self, x):
        h = self.nn(x)
        # Concatenate along the channel dimension: n_inputs + growth
        # channels out.
        h = torch.cat([x, h], dim=1)
        return h
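
# For intuition (illustrative shapes): a DenseLayer with n_inputs=8 and
# growth=4 maps a (B, 8, H, W) tensor to (B, 12, H, W), since the 4 newly
# computed feature maps are concatenated onto the 8 input channels.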


class DenseBlock(Base):
    """Stack of DenseLayers growing the channel count from n_inputs to
    n_outputs over args.densenet_depth layers."""

    def __init__(
            self, args, n_inputs, n_outputs, kernel, Conv2dAct):
        super().__init__()
        depth = args.densenet_depth

        future_growth = n_outputs - n_inputs

        layers = []
        for d in range(depth):
            # Split the remaining growth evenly over the remaining layers.
            growth = future_growth // (depth - d)

            layers.append(DenseLayer(args, n_inputs, growth, Conv2dAct))
            n_inputs += growth
            future_growth -= growth

        self.nn = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.nn(x)
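
# Worked example of the growth schedule (illustrative numbers): with
# n_inputs=4, n_outputs=68 and args.densenet_depth=8, future_growth starts
# at 64 and every layer gets growth = future_growth // (depth - d) = 8, so
# the channel count runs 4 -> 12 -> 20 -> ... -> 68. When the total growth
# does not divide evenly, later layers receive slightly more (e.g. a total
# growth of 10 over depth 4 splits as 2, 2, 3, 3); the final width always
# lands exactly on n_outputs.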


class Identity(Base):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class NN(Base):
    """Network used inside the flows: builds a shallow, ResNet or
    DenseNet mapping from c_in to c_out channels."""

    def __init__(
            self, args, c_in, c_out, height, width, nn_type, kernel=3):
        super().__init__()
        # height and width are currently unused, but kept for a uniform
        # constructor signature.
        Conv2dAct = Conv2dReLU

        n_channels = args.n_channels

        if nn_type == 'shallow':
            if args.network1x1 == 'standard':
                conv1x1 = Conv2dAct(
                    n_channels, n_channels, kernel_size=1, stride=1,
                    padding=0, bias=False)
            else:
                raise ValueError(
                    f'Unknown network1x1: {args.network1x1}')

            layers = [
                Conv2dAct(c_in, n_channels, kernel, padding=1),
                conv1x1]

            layers += [torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)]

        elif nn_type == 'resnet':
            layers = [
                Conv2dAct(c_in, n_channels, kernel, padding=1),
                ResidualBlock(n_channels, kernel, Conv2dAct),
                ResidualBlock(n_channels, kernel, Conv2dAct)]

            layers += [
                torch.nn.Conv2d(n_channels, c_out, kernel, padding=1)
            ]

        elif nn_type == 'densenet':
            layers = [
                DenseBlock(
                    args=args,
                    n_inputs=c_in,
                    n_outputs=n_channels + c_in,
                    kernel=kernel,
                    Conv2dAct=Conv2dAct)]

            layers += [
                torch.nn.Conv2d(n_channels + c_in, c_out, kernel, padding=1)
            ]

        else:
            raise ValueError(f'Unknown nn_type: {nn_type}')

        self.nn = torch.nn.Sequential(*layers)

        # Zero-initialise the final convolution so the network outputs
        # zeros at the start of training (skipped in unit tests).
        if not UNIT_TESTING:
            self.nn[-1].weight.data.zero_()
            self.nn[-1].bias.data.zero_()

    def forward(self, x):
        return self.nn(x)
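

if __name__ == '__main__':
    # Minimal smoke test (a sketch): the hyperparameter values below are
    # illustrative, not from the original training configuration, and the
    # script assumes it is run from the repository root so that
    # models.utils resolves.
    from types import SimpleNamespace

    args = SimpleNamespace(n_channels=32, densenet_depth=4,
                           network1x1='standard')

    x = torch.randn(2, 3, 8, 8)
    for nn_type in ('shallow', 'resnet', 'densenet'):
        net = NN(args, c_in=3, c_out=6, height=8, width=8, nn_type=nn_type)
        y = net(x)
        assert y.shape == (2, 6, 8, 8)
        # With UNIT_TESTING = False the last conv is zero-initialised,
        # so every variant maps any input to all zeros at initialisation.
        assert torch.all(y == 0)
        print(nn_type, 'ok:', tuple(y.shape))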