feat: measuring code + graph generator code

This commit is contained in:
RobinMeersman 2025-12-15 22:53:32 +01:00
parent dd0b3d3945
commit f3b07c1df3
6 changed files with 325 additions and 140 deletions

View file

@ -58,8 +58,6 @@ class AutoEncoder(Model):
"""
x: torch.Tensor of floats
"""
if len(x.shape) == 2:
x = x.unsqueeze(1)
return self.decoder(x)
def forward(self, x: torch.LongTensor) -> torch.Tensor: