"""Convolutional 1-D autoencoder (reconstructed from a broken backup patch).

Fixes relative to the patched backup:
  * ``Decoder.__init__`` assigned its stack to ``super._decoder`` (an
    attribute on the ``super`` builtin, AttributeError) -> ``self._decoder``.
  * ``nn.Linear(input_size)`` was missing ``out_features`` (TypeError).
  * ``nn.Linear`` applied to an (N, C, L) conv output acts on the *length*
    axis, not the channels; a kernel-size-1 ``Conv1d`` is the per-position
    linear projection that was intended, in both Encoder and Decoder.
  * ``Decoder.forward`` was ``pass`` (returned ``None``); it now applies
    the stack.
  * ``AutoEncoder.decode`` had ``return`` split from ``self.decoder(x)``
    by a line break (returned ``None``); rejoined.
"""
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Downsample (N, input_size, L) -> (N, latent_dim, ceil(L / 2))."""

    def __init__(self, input_size: int, hidden_size: int, latent_dim: int):
        super().__init__()
        self._encoder = nn.Sequential(
            nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            # stride=2 halves the temporal length
            nn.Conv1d(hidden_size, 2 * hidden_size, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm1d(2 * hidden_size),
            nn.ReLU(),
            # kernel_size=1 conv == per-position linear onto the latent channels
            nn.Conv1d(2 * hidden_size, latent_dim, kernel_size=1),
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._encoder(x)


class Decoder(nn.Module):
    """Mirror of Encoder: (N, input_size, L) -> (N, output_size, 2 * L)."""

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self._decoder = nn.Sequential(
            # per-position linear on the latent channels (was the broken nn.Linear)
            nn.Conv1d(input_size, input_size, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(input_size),
            # output_padding=1 with stride=2 makes the length exactly double
            nn.ConvTranspose1d(input_size, 2 * hidden_size, kernel_size=3,
                               stride=2, padding=1, output_padding=1),
            nn.BatchNorm1d(2 * hidden_size),
            nn.ReLU(),
            nn.ConvTranspose1d(2 * hidden_size, output_size, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._decoder(x)


class AutoEncoder(nn.Module):
    """Encoder + Decoder pair; forward() reconstructs the input's shape."""

    def __init__(self, input_size: int, hidden_size: int, latent_dim: int):
        super().__init__()
        self.encoder = Encoder(input_size, hidden_size, latent_dim)
        # Decoder mirrors the encoder: latent channels back to input channels.
        self.decoder = Decoder(latent_dim, hidden_size, input_size)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))