feat: new CNN, start of creating graphs

This commit is contained in:
RobinMeersman 2025-12-14 18:36:40 +01:00
parent 17e0b52600
commit 5bb254d6c2
7 changed files with 151 additions and 49 deletions

View file

@ -50,12 +50,16 @@ class AutoEncoder(Model):
"""
x: torch.Tensor of floats
"""
if len(x.shape) == 2:
x = x.unsqueeze(1)
return self.encoder(x)
def decode(self, x: torch.Tensor) -> torch.Tensor:
"""
x: torch.Tensor of floats
"""
if len(x.shape) == 2:
x = x.unsqueeze(1)
return self.decoder(x)
def forward(self, x: torch.LongTensor) -> torch.Tensor: