feat: Add model choice

This commit is contained in:
Tibo De Peuter 2025-12-06 21:52:31 +01:00
parent bb241154d9
commit ef50d6321e
Signed by: tdpeuter
GPG key ID: 38297DE43F75FFE2
10 changed files with 102 additions and 54 deletions

14
src/models/Model.py Normal file
View file

@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
from torch import nn
class Model(nn.Module, ABC):
@abstractmethod
def __init__(self, loss_function = None):
super().__init__()
self._loss_function = loss_function
@property
def loss_function(self):
return self._loss_function

View file

@ -1,2 +1,9 @@
from .Model import Model
from .cnn import CNNPredictor
from .transformer import Transformer
from .transformer import Transformer
model_called: dict[str, type[Model]] = {
'cnn': CNNPredictor,
'transformer': Transformer
}

View file

@ -1,14 +1,16 @@
import torch
import torch.nn as nn
class CNNPredictor(nn.Module):
from src.models import Model
class CNNPredictor(Model):
def __init__(
self,
vocab_size=256,
embed_dim=64,
hidden_dim=128,
):
super().__init__()
super().__init__(nn.CrossEntropyLoss())
# 1. Embedding: maps bytes (0255) → vectors
self.embed = nn.Embedding(vocab_size, embed_dim)

View file

@ -30,6 +30,7 @@ class Transformer(nn.Transformer):
device=None,
dtype=None
)
self.loss_function = nn.CrossEntropyLoss()
def forward(
self,