This repository was archived on 2025-12-23. You can view files and clone it, but you cannot make any changes to its state, such as pushing and creating new issues, pull requests or comments.
2025ML-project-neural_compr.../src/models/autoencoder/autoencoder.py

68 lines
No EOL
2.3 KiB
Python

import torch
import torch.nn as nn
from src.models import Model
class Encoder(nn.Module):
    """Convolutional encoder mapping a 1-channel sequence to a latent vector.

    Layers: a length-preserving Conv1d to ``channel_count`` channels, a
    stride-2 Conv1d to ``2 * channel_count`` channels, then Flatten and a
    Linear + ReLU projection to ``latent_dim`` features.

    Args:
        data_length: Length ``L`` of the input sequence (last dimension).
        channel_count: Number of channels produced by the first convolution.
        latent_dim: Size of the output latent vector.
    """

    def __init__(self, data_length, channel_count, latent_dim):
        super(Encoder, self).__init__()
        # A stride-2, kernel-3, padding-1 convolution yields
        # floor((L + 2 - 3) / 2) + 1 == ceil(L / 2) positions. Using L // 2
        # here would under-size the Linear layer for odd L.
        reduced_length = (data_length + 1) // 2
        # Layer order is kept fixed so state_dict keys (``_encoder.<i>.*``)
        # stay stable across versions.
        self._encoder = nn.Sequential(
            nn.Conv1d(1, channel_count, kernel_size=3, padding=1),  # (C, L)
            nn.BatchNorm1d(channel_count),
            nn.ReLU(),
            nn.Conv1d(
                channel_count, 2 * channel_count, stride=2, kernel_size=3, padding=1
            ),  # (2C, ceil(L / 2))
            nn.BatchNorm1d(2 * channel_count),
            nn.Flatten(),  # 2C * ceil(L / 2)
            nn.Linear(2 * channel_count * reduced_length, latent_dim),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` into a ``(B, latent_dim)`` tensor.

        ``x`` is expected to be a float tensor of shape ``(B, 1, data_length)``:
        the first Conv1d takes 1 input channel and the Linear layer is sized
        from ``data_length``.
        """
        return self._encoder(x)
class Decoder(nn.Module):
    """Decoder mapping a latent vector back to a 1-channel sequence.

    Layers: Linear + ReLU, Unflatten to ``(2 * channel_count, ceil(L / 2))``,
    a stride-2 ConvTranspose1d up to ``channel_count`` channels and length
    ``L``, then a length-preserving ConvTranspose1d to 1 channel. The last
    layer has no activation.

    Args:
        latent_dim: Size of the input latent vector.
        channel_count: Base channel count (presumably the value given to the
            matching ``Encoder`` — confirm at the call site).
        data_length: Length ``L`` of the reconstructed sequence.
    """

    def __init__(self, latent_dim, channel_count, data_length):
        super(Decoder, self).__init__()
        # Length after a stride-2, kernel-3, padding-1 downsampling: ceil(L / 2).
        reduced_length = (data_length + 1) // 2
        # ConvTranspose1d(kernel=3, stride=2, padding=1) maps n positions to
        # 2 * n - 1 + output_padding. Choosing output_padding = 1 for even L
        # and 0 for odd L makes the result exactly L in both cases.
        output_padding = 1 - data_length % 2
        # Layer order is kept fixed so state_dict keys (``_decoder.<i>.*``)
        # stay stable across versions.
        self._decoder = nn.Sequential(
            nn.Linear(latent_dim, 2 * channel_count * reduced_length),
            nn.ReLU(),
            nn.Unflatten(1, (2 * channel_count, reduced_length)),
            nn.BatchNorm1d(2 * channel_count),
            nn.ConvTranspose1d(
                2 * channel_count,
                channel_count,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=output_padding,
            ),  # (C, L)
            nn.BatchNorm1d(channel_count),
            nn.ReLU(),
            nn.ConvTranspose1d(channel_count, 1, kernel_size=3, padding=1),  # (1, L)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a ``(B, latent_dim)`` tensor into a ``(B, 1, data_length)`` tensor."""
        return self._decoder(x)
class AutoEncoder(Model):
    """Convolutional autoencoder built from an ``Encoder``/``Decoder`` pair.

    Both halves share ``channel_count`` and the sequence length
    ``input_size``. ``nn.MSELoss`` is handed to the ``Model`` base class as
    the loss function (how it is applied is defined by ``Model``, not here).

    Args:
        input_size: Sequence length ``L`` handled by encoder and decoder.
        channel_count: Base number of convolution channels.
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_size, channel_count, latent_dim):
        super().__init__(loss_function=nn.MSELoss())
        # Encoder is registered before decoder; keep this order so parameter
        # and state_dict ordering is unchanged.
        self.encoder = Encoder(input_size, channel_count, latent_dim)
        self.decoder = Decoder(latent_dim, channel_count, input_size)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run only the encoder on ``x``.

        ``x`` must already be a float tensor in the encoder's input layout;
        unlike ``forward``, no scaling or channel insertion is applied here.
        """
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Run only the decoder on a float latent tensor ``x``."""
        return self.decoder(x)

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """Scale integer input by 1/255, add a channel axis and reconstruct it.

        Assumes ``x`` holds byte values in ``[0, 255]`` with shape ``(B, L)``
        (TODO confirm with the data pipeline); ``unsqueeze(1)`` turns it into
        ``(B, 1, L)`` before encoding.
        """
        scaled = x.float().div(255.0)
        with_channel = scaled.unsqueeze(1)
        latent = self.encoder(with_channel)
        return self.decoder(latent)