feat: autoencoder + updated trainers + cleaned up process to allow using autoencoder
This commit is contained in:
parent
0ab495165f
commit
17e0b52600
11 changed files with 132 additions and 211 deletions
|
|
@ -47,9 +47,15 @@ class AutoEncoder(Model):
|
|||
self.decoder = Decoder(latent_dim, channel_count, input_size)
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: torch.Tensor of floats
|
||||
"""
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
x: torch.Tensor of floats
|
||||
"""
|
||||
return self.decoder(x)
|
||||
|
||||
def forward(self, x: torch.LongTensor) -> torch.Tensor:
|
||||
|
|
|
|||
Reference in a new issue