commit 41d22d9dd5
parent b178c097d8
RobinMeersman, 2025-12-13 15:08:58 +01:00


@@ -3,16 +3,50 @@ import torch.nn as nn
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_dim):
        super(Encoder, self).__init__()
        self._encoder = nn.Sequential(
            nn.Conv1d(input_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            # Stride-2 convolution halves the sequence length while doubling the channels.
            nn.Conv1d(hidden_size, 2 * hidden_size, stride=2, kernel_size=3, padding=1),
            nn.BatchNorm1d(2 * hidden_size),
            # Channel-wise projection to the latent dimension. A 1x1 convolution is used
            # here instead of nn.Linear, which would act on the sequence (last) dimension
            # of the (batch, channels, length) tensor and fail for arbitrary lengths.
            nn.Conv1d(2 * hidden_size, latent_dim, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._encoder(x)

class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(Decoder, self).__init__()
        self._decoder = nn.Sequential(
            # Channel-wise projection of the latent code (1x1 convolution; the original
            # nn.Linear call was missing its out_features argument).
            nn.Conv1d(input_size, input_size, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(input_size),
            # Stride-2 transposed convolution undoes the encoder's downsampling.
            nn.ConvTranspose1d(input_size, 2 * hidden_size, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm1d(2 * hidden_size),
            nn.ReLU(),
            nn.ConvTranspose1d(2 * hidden_size, output_size, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._decoder(x)

class AutoEncoder(nn.Module):
    def __init__(self, input_size, hidden_size, latent_dim):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(input_size, hidden_size, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_size, input_size)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(self.encode(x))
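
For reference, a minimal smoke test of the new module (not part of the commit; the sizes 8/16/4 and the tensor shapes are made up for illustration). It assumes an even input length, so that the encoder's stride-2 downsampling and the decoder's stride-2 upsampling restore the original length exactly:

```python
import torch

# Hypothetical sizes: 8 input channels, hidden width 16, latent width 4.
model = AutoEncoder(input_size=8, hidden_size=16, latent_dim=4)

x = torch.randn(2, 8, 64)   # (batch, channels, length); even length reconstructs exactly
z = model.encode(x)         # (2, 4, 32) -- stride-2 conv halves the length
x_hat = model(x)            # (2, 8, 64) -- same shape as the input
print(z.shape, x_hat.shape)
```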