feat: Add model choice
This commit is contained in:
parent
bb241154d9
commit
ef50d6321e
10 changed files with 102 additions and 54 deletions
14
src/models/Model.py
Normal file
14
src/models/Model.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Model(nn.Module, ABC):
|
||||
@abstractmethod
|
||||
def __init__(self, loss_function = None):
|
||||
super().__init__()
|
||||
self._loss_function = loss_function
|
||||
|
||||
@property
|
||||
def loss_function(self):
|
||||
return self._loss_function
|
||||
|
|
@ -1,2 +1,9 @@
|
|||
from .Model import Model
|
||||
from .cnn import CNNPredictor
|
||||
from .transformer import Transformer
|
||||
from .transformer import Transformer
|
||||
|
||||
|
||||
model_called: dict[str, type[Model]] = {
|
||||
'cnn': CNNPredictor,
|
||||
'transformer': Transformer
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,16 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class CNNPredictor(nn.Module):
|
||||
from src.models import Model
|
||||
|
||||
|
||||
class CNNPredictor(Model):
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=256,
|
||||
embed_dim=64,
|
||||
hidden_dim=128,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(nn.CrossEntropyLoss())
|
||||
|
||||
# 1. Embedding: maps bytes (0–255) → vectors
|
||||
self.embed = nn.Embedding(vocab_size, embed_dim)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ class Transformer(nn.Transformer):
|
|||
device=None,
|
||||
dtype=None
|
||||
)
|
||||
self.loss_function = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
|||
Reference in a new issue