import torch
import torch.nn as nn
from src.models import Model


def _downsampled_length(data_length: int) -> int:
    """Return the sequence length after the encoder's stride-2 convolution.

    For ``Conv1d(kernel_size=3, stride=2, padding=1)`` the output length is
    ``floor((L - 1) / 2) + 1 == ceil(L / 2)``, which equals ``L // 2`` only
    when ``L`` is even.
    """
    return (data_length + 1) // 2


class Encoder(nn.Module):
    """Convolutional encoder: ``(B, 1, L)`` input -> ``(B, latent_dim)`` code.

    Args:
        data_length: Sequence length ``L`` of the input. Odd lengths are
            supported; the stride-2 convolution yields ``ceil(L / 2)`` steps.
        channel_count: Channels produced by the first convolution; the second
            convolution doubles it.
        latent_dim: Size of the latent vector produced by the final ``Linear``.
    """

    def __init__(self, data_length: int, channel_count: int, latent_dim: int):
        super().__init__()
        reduced_length = _downsampled_length(data_length)
        # Module order is kept as in the original so state_dict keys
        # (``_encoder.<index>.*``) stay compatible with saved checkpoints.
        self._encoder = nn.Sequential(
            nn.Conv1d(1, channel_count, kernel_size=3, padding=1),  # (C, L)
            nn.BatchNorm1d(channel_count),
            nn.ReLU(),
            nn.Conv1d(
                channel_count, 2 * channel_count, stride=2, kernel_size=3, padding=1
            ),  # (2C, ceil(L / 2))
            nn.BatchNorm1d(2 * channel_count),
            # NOTE(review): unlike the first conv stage there is no activation
            # here. Adding one would shift Sequential indices and break
            # existing state_dict keys, so it is intentionally left as-is.
            nn.Flatten(),  # 2C * ceil(L / 2)
            # Size derived from the real conv output length, not L // 2,
            # so odd L no longer causes a shape mismatch.
            nn.Linear(2 * channel_count * reduced_length, latent_dim),
            nn.ReLU(),  # latent code is non-negative
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` (expected shape ``(B, 1, L)``) to ``(B, latent_dim)``."""
        return self._encoder(x)


class Decoder(nn.Module):
    """Mirror of :class:`Encoder`: ``(B, latent_dim)`` -> ``(B, 1, L)``.

    Args:
        latent_dim: Size of the incoming latent vector.
        channel_count: Base channel count; must match the paired encoder.
        data_length: Target output length ``L``; must match the paired encoder.
    """

    def __init__(self, latent_dim: int, channel_count: int, data_length: int):
        super().__init__()
        reduced_length = _downsampled_length(data_length)
        # ConvTranspose1d(kernel_size=3, stride=2, padding=1) maps length n to
        # 2*n - 1 + output_padding. With n = ceil(L / 2) this equals L exactly
        # when output_padding is 1 for even L and 0 for odd L.
        output_padding = 1 if data_length % 2 == 0 else 0
        # Module order is kept as in the original for state_dict compatibility.
        self._decoder = nn.Sequential(
            nn.Linear(latent_dim, 2 * channel_count * reduced_length),
            nn.ReLU(),
            nn.Unflatten(1, (2 * channel_count, reduced_length)),  # (2C, ceil(L / 2))
            nn.BatchNorm1d(2 * channel_count),
            nn.ConvTranspose1d(
                2 * channel_count,
                channel_count,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=output_padding,
            ),  # (C, L)
            nn.BatchNorm1d(channel_count),
            nn.ReLU(),
            nn.ConvTranspose1d(channel_count, 1, kernel_size=3, padding=1),  # (1, L)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a ``(B, latent_dim)`` code to a ``(B, 1, L)`` reconstruction."""
        return self._decoder(x)


class AutoEncoder(Model):
    """Convolutional autoencoder trained with mean-squared-error loss.

    Args:
        input_size: Sequence length ``L`` of each sample (odd or even).
        channel_count: Base channel count of the encoder/decoder convolutions.
        latent_dim: Size of the bottleneck vector.
    """

    def __init__(self, input_size: int, channel_count: int, latent_dim: int):
        super().__init__(loss_function=nn.MSELoss())
        self.encoder = Encoder(input_size, channel_count, latent_dim)
        self.decoder = Decoder(latent_dim, channel_count, input_size)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run only the encoder.

        x: float tensor, already scaled and channel-expanded; unlike
        ``forward`` no ``/ 255`` scaling or ``unsqueeze`` is applied here.
        Presumably ``(B, 1, L)`` -- confirm against callers.
        """
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Run only the decoder on a float latent tensor ``(B, latent_dim)``."""
        return self.decoder(x)

    def forward(self, x: torch.LongTensor) -> torch.Tensor:
        """Scale, encode and decode a batch of integer sequences.

        x: integer tensor of shape ``(B, L)``; values are divided by 255, so
        presumably byte values in ``[0, 255]`` -- confirm against the dataset.

        Returns the ``(B, 1, L)`` float reconstruction in the scaled space.
        """
        x = x.float() / 255.0  # convert to floats in the scaled range
        x = x.unsqueeze(1)  # add channel dimension --> (B, 1, L)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded