feat: new CNN, start of creating graphs
This commit is contained in:
parent
17e0b52600
commit
5bb254d6c2
7 changed files with 151 additions and 49 deletions
|
|
@ -50,12 +50,16 @@ class AutoEncoder(Model):
|
|||
"""
|
||||
x: torch.Tensor of floats
|
||||
"""
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(1)
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: torch.Tensor of floats
|
||||
"""
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(1)
|
||||
return self.decoder(x)
|
||||
|
||||
def forward(self, x: torch.LongTensor) -> torch.Tensor:
|
||||
|
|
|
|||
Reference in a new issue