backup
parent 41d22d9dd5
commit 6e591bb470
3 changed files with 13 additions and 8 deletions

@@ -1,8 +1,10 @@
 from .Model import Model
+from .autoencoder import AutoEncoder
 from .cnn import CNNPredictor
 from .transformer import ByteTransformer

 model_called: dict[str, type[Model]] = {
     'cnn': CNNPredictor,
-    'transformer': ByteTransformer
+    'transformer': ByteTransformer,
+    'autoencoder': AutoEncoder
 }
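
For context, a minimal sketch of how the model_called registry might be used to pick a model class by its config name; the import path and the constructor arguments are assumptions, not part of this commit:

# Hypothetical registry lookup; module path and constructor arguments are assumed.
from src.models import model_called

model_cls = model_called['autoencoder']                            # type[Model]
model = model_cls(input_size=256, hidden_size=64, latent_dim=16)   # args assumed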

src/models/autoencoder/__init__.py (new file, 1 addition)
@@ -0,0 +1 @@
+from .autoencoder import AutoEncoder

@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn

+from src.models import Model
+

 class Encoder(nn.Module):
     def __init__(self, input_size, hidden_size, latent_dim):
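
The Model base class imported here is not shown in this diff; a plausible minimal shape, inferred from the super().__init__(loss_function=...) call further down, might be:

# Assumed sketch of the Model base class -- not part of this commit.
import torch.nn as nn

class Model(nn.Module):
    """Stores the loss function that subclasses pass up (assumed interface)."""
    def __init__(self, loss_function: nn.Module):
        super().__init__()
        self.loss_function = loss_function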

@@ -23,21 +25,21 @@ class Decoder(nn.Module):
     def __init__(self, input_size, hidden_size, output_size):
         super(Decoder, self).__init__()
         super._decoder = nn.Sequential(*[
-            nn.Linear(input_size),
+            nn.Linear(input_size, 2 * hidden_size),
             nn.ReLU(),
-            nn.BatchNorm1d(input_size),
-            nn.ConvTranspose1d(input_size, 2 * hidden_size, kernel_size=3, stride=2, padding=1, output_padding=1),
             nn.BatchNorm1d(2 * hidden_size),
+            nn.ConvTranspose1d(2 * hidden_size, hidden_size, kernel_size=3, stride=2, padding=1, output_padding=1),
+            nn.BatchNorm1d(hidden_size),
             nn.ReLU(),
-            nn.ConvTranspose1d(2 * hidden_size, output_size, kernel_size=3, padding=1),
+            nn.ConvTranspose1d(hidden_size, output_size, kernel_size=3, padding=1),
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        pass
+        return self._decoder(x)

-class AutoEncoder(nn.Module):
+class AutoEncoder(Model):
     def __init__(self, input_size, hidden_size, latent_dim):
-        super(AutoEncoder, self).__init__()
+        super().__init__(loss_function = nn.CrossEntropyLoss())

         self.encoder = Encoder(input_size, hidden_size, latent_dim)
         self.decoder = Decoder(latent_dim, hidden_size, input_size)
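
The rewired decoder now chains channel counts as 2 * hidden_size -> hidden_size -> output_size. A small standalone check of the length arithmetic for the changed ConvTranspose1d settings; the concrete sizes here are arbitrary, not taken from this repository:

# Standalone shape check of the decoder's transposed-conv settings; sizes are arbitrary.
import torch
import torch.nn as nn

hidden_size, output_size = 64, 256
x = torch.randn(8, 2 * hidden_size, 16)  # (batch, channels, length)

up = nn.ConvTranspose1d(2 * hidden_size, hidden_size,
                        kernel_size=3, stride=2, padding=1, output_padding=1)
y = up(x)
print(y.shape)       # torch.Size([8, 64, 32]); L_out = (L-1)*2 - 2*1 + 3 + 1 = 2*L

out = nn.ConvTranspose1d(hidden_size, output_size, kernel_size=3, padding=1)
print(out(y).shape)  # torch.Size([8, 256, 32]); stride=1, padding=1 keeps the length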