feat: autoencoder + updated trainers + cleaned up process to allow using autoencoder
This commit is contained in:
parent
0ab495165f
commit
17e0b52600
11 changed files with 132 additions and 211 deletions
13
main.py
13
main.py
|
|
@ -1,5 +1,5 @@
|
||||||
from src.args import parse_arguments
|
from src.args import parse_arguments
|
||||||
from src.process import compress
|
from src.process import compress, decompress
|
||||||
from src.train import train
|
from src.train import train
|
||||||
from src.utils import determine_device
|
from src.utils import determine_device
|
||||||
|
|
||||||
|
|
@ -38,12 +38,21 @@ def main():
|
||||||
|
|
||||||
case 'compress':
|
case 'compress':
|
||||||
compress(device=device,
|
compress(device=device,
|
||||||
|
model_name=args.model,
|
||||||
model_path=args.model_load_path,
|
model_path=args.model_load_path,
|
||||||
input_file=args.input_file,
|
input_file=args.input_file,
|
||||||
output_file=args.output_file,
|
output_file=args.output_file,
|
||||||
context_length=args.context
|
context_length=args.context
|
||||||
)
|
)
|
||||||
|
case 'decompress':
|
||||||
|
decompress(
|
||||||
|
device=device,
|
||||||
|
model_name=args.model,
|
||||||
|
model_path=args.model_load_path,
|
||||||
|
input_file=args.input_file,
|
||||||
|
output_file=args.output_file,
|
||||||
|
context_length=args.context
|
||||||
|
)
|
||||||
case _:
|
case _:
|
||||||
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
|
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,8 @@
|
||||||
from .Model import Model
|
from .Model import Model
|
||||||
from .autoencoder import AutoEncoder
|
from .autoencoder import AutoEncoder
|
||||||
from .cnn import CNNPredictor
|
from .cnn import CNNPredictor
|
||||||
from .transformer import ByteTransformer
|
|
||||||
|
|
||||||
model_called: dict[str, type[Model]] = {
|
model_called: dict[str, type[Model]] = {
|
||||||
'cnn': CNNPredictor,
|
'cnn': CNNPredictor,
|
||||||
'transformer': ByteTransformer,
|
|
||||||
'autoencoder': AutoEncoder
|
'autoencoder': AutoEncoder
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -47,9 +47,15 @@ class AutoEncoder(Model):
|
||||||
self.decoder = Decoder(latent_dim, channel_count, input_size)
|
self.decoder = Decoder(latent_dim, channel_count, input_size)
|
||||||
|
|
||||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: torch.Tensor of floats
|
||||||
|
"""
|
||||||
return self.encoder(x)
|
return self.encoder(x)
|
||||||
|
|
||||||
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
def decode(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
x: torch.Tensor of floats
|
||||||
|
"""
|
||||||
return self.decoder(x)
|
return self.decoder(x)
|
||||||
|
|
||||||
def forward(self, x: torch.LongTensor) -> torch.Tensor:
|
def forward(self, x: torch.LongTensor) -> torch.Tensor:
|
||||||
|
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
from .transformer import ByteTransformer
|
|
||||||
|
|
@ -1,70 +0,0 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
from torch import Tensor, arange
|
|
||||||
|
|
||||||
from src.models import Model
|
|
||||||
|
|
||||||
|
|
||||||
class LearnedPositionalEncoding(Model):
|
|
||||||
def __init__(self, max_len, d_model):
|
|
||||||
super().__init__()
|
|
||||||
self.pos_emb = nn.Embedding(max_len, d_model)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
# x: [seq, batch, d_model]
|
|
||||||
seq_len = x.size(0)
|
|
||||||
positions = arange(seq_len, device=x.device).unsqueeze(1) # [seq, 1]
|
|
||||||
return x + self.pos_emb(positions) # broadcast over batch
|
|
||||||
|
|
||||||
class ByteTransformer(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
d_model=512,
|
|
||||||
nhead=8,
|
|
||||||
num_encoder_layers=6,
|
|
||||||
num_decoder_layers=6,
|
|
||||||
dim_feedforward=2048,
|
|
||||||
dropout=0.1,
|
|
||||||
activation="relu",
|
|
||||||
layer_norm_eps=1e-05,
|
|
||||||
max_len=128
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.src_embedding = nn.Embedding(256, d_model)
|
|
||||||
self.tgt_embedding = nn.Embedding(256, d_model)
|
|
||||||
|
|
||||||
self.src_pos = LearnedPositionalEncoding(max_len, d_model)
|
|
||||||
self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
|
|
||||||
|
|
||||||
self.transformer = nn.Transformer(
|
|
||||||
d_model=d_model,
|
|
||||||
nhead=nhead,
|
|
||||||
num_encoder_layers=num_encoder_layers,
|
|
||||||
num_decoder_layers=num_decoder_layers,
|
|
||||||
dim_feedforward=dim_feedforward,
|
|
||||||
dropout=dropout,
|
|
||||||
activation=activation,
|
|
||||||
layer_norm_eps=layer_norm_eps,
|
|
||||||
batch_first=False,
|
|
||||||
norm_first=False,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.output_proj = nn.Linear(d_model, 256)
|
|
||||||
|
|
||||||
self.loss_function = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
src: Tensor,
|
|
||||||
tgt: Tensor,
|
|
||||||
) -> Tensor:
|
|
||||||
src_embeds = self.src_embedding(src)
|
|
||||||
tgt_embeds = self.tgt_embedding(tgt)
|
|
||||||
|
|
||||||
src_pos = self.src_pos(src_embeds)
|
|
||||||
tgt_pos = self.tgt_pos(tgt_embeds)
|
|
||||||
|
|
||||||
return self.output_proj(self.transformer(src_pos, tgt_pos))
|
|
||||||
194
src/process.py
194
src/process.py
|
|
@ -1,11 +1,14 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from src.models import AutoEncoder
|
||||||
from src.utils import reference_ae
|
from src.utils import reference_ae
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -22,9 +25,74 @@ def probs_to_freqs(probs, total_freq=8192):
|
||||||
return freqs
|
return freqs
|
||||||
|
|
||||||
|
|
||||||
|
def ae_compress(
|
||||||
|
output_file: str,
|
||||||
|
context_length: int,
|
||||||
|
device: str,
|
||||||
|
model: nn.Module,
|
||||||
|
byte_data: bytes,
|
||||||
|
tensor: torch.Tensor
|
||||||
|
|
||||||
|
):
|
||||||
|
# Init AE
|
||||||
|
print("Initializing AE")
|
||||||
|
|
||||||
|
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
|
||||||
|
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
|
||||||
|
|
||||||
|
context = deque([0] * context_length, maxlen=context_length)
|
||||||
|
|
||||||
|
# Compress
|
||||||
|
for byte in tqdm(tensor.tolist(), desc="Compressing"):
|
||||||
|
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
logits = model(context_tensor)
|
||||||
|
# normalize
|
||||||
|
mean = logits.mean(dim=-1, keepdim=True)
|
||||||
|
std = logits.std(dim=-1, keepdim=True)
|
||||||
|
logits = (logits - mean) / (std + 1e-6)
|
||||||
|
print(f"logits: {logits}")
|
||||||
|
probabilities = torch.softmax(logits[0], dim=-1)
|
||||||
|
print(f"probabilities: {probabilities}")
|
||||||
|
probabilities = probabilities.detach()
|
||||||
|
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
|
||||||
|
|
||||||
|
# write byte to output file
|
||||||
|
enc.write(probability_table, byte)
|
||||||
|
|
||||||
|
context.append(byte)
|
||||||
|
|
||||||
|
def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
|
||||||
|
tensor_data = torch.tensor(list(x), dtype=torch.long)
|
||||||
|
shape = tensor_data.size(0)
|
||||||
|
row_count = math.ceil(shape / context_length)
|
||||||
|
pad_count = row_count * context_length - shape
|
||||||
|
tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
|
||||||
|
return tensor_data.view(row_count, context_length)
|
||||||
|
|
||||||
|
def auto_encoder_compress(
|
||||||
|
data: bytes,
|
||||||
|
model: AutoEncoder,
|
||||||
|
output_file: str,
|
||||||
|
context_length: int = 128,
|
||||||
|
device: str = "cuda"
|
||||||
|
):
|
||||||
|
# convert data to chunks of context length tensors
|
||||||
|
# send the data to device
|
||||||
|
tensor = chunk_data(data, context_length).to(device)
|
||||||
|
|
||||||
|
# compress
|
||||||
|
output = model.encode(tensor)
|
||||||
|
print(output.shape)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def compress(
|
def compress(
|
||||||
device,
|
device,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
|
model_name: str,
|
||||||
context_length: int = 128,
|
context_length: int = 128,
|
||||||
input_file: str | None = None,
|
input_file: str | None = None,
|
||||||
output_file: str | None = None
|
output_file: str | None = None
|
||||||
|
|
@ -48,88 +116,42 @@ def compress(
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# Init AE
|
match model_name:
|
||||||
print("Initializing AE")
|
case "cnn":
|
||||||
|
ae_compress(
|
||||||
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
|
output_file,
|
||||||
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
|
context_length,
|
||||||
|
device,
|
||||||
context = deque([0] * context_length, maxlen=context_length)
|
model,
|
||||||
|
byte_data,
|
||||||
stage_min, stage_max = Decimal(0), Decimal(1)
|
tensor
|
||||||
stage = None
|
)
|
||||||
|
case "autoencoder":
|
||||||
# Compress
|
auto_encoder_compress()
|
||||||
for byte in tqdm(tensor.tolist(), desc="Compressing"):
|
case _:
|
||||||
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
|
raise ValueError(f"Unknown model type: {model_name}")
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
logits = model(context_tensor)
|
|
||||||
#normalize
|
|
||||||
mean = logits.mean(dim=-1, keepdim=True)
|
|
||||||
std = logits.std(dim=-1, keepdim=True)
|
|
||||||
logits = (logits - mean) / (std + 1e-6)
|
|
||||||
print(f"logits: {logits}")
|
|
||||||
probabilities = torch.softmax(logits[0], dim=-1)
|
|
||||||
print(f"probabilities: {probabilities}")
|
|
||||||
probabilities = probabilities.detach()
|
|
||||||
|
|
||||||
eps = 1e-8
|
|
||||||
# np.add(probabilities, eps)
|
|
||||||
# frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
|
|
||||||
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
|
|
||||||
# probability_table = AE.get_probability_table(frequency_table)
|
|
||||||
|
|
||||||
enc.write(probability_table, byte)
|
|
||||||
|
|
||||||
context.append(byte)
|
|
||||||
|
|
||||||
# print("Getting encoded value")
|
|
||||||
# interval_min, interval_max, _ = AE.get_encoded_value(stage)
|
|
||||||
# print("Encoding in binary")
|
|
||||||
# binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
|
|
||||||
|
|
||||||
# Pack
|
|
||||||
# val = int(binary_code, 2) if len(binary_code) else 0
|
|
||||||
# out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
|
|
||||||
|
|
||||||
# if output_file:
|
|
||||||
# print(f"Writing to {output_file}")
|
|
||||||
# with open(output_file, "w") as file:
|
|
||||||
# file.write(f"{len(byte_data)}\n")
|
|
||||||
# file.write(binary_code) # todo: temporary, decoding depends on binary string
|
|
||||||
# else:
|
|
||||||
# print(out_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
def bits_to_number(bits: str) -> float:
|
|
||||||
n = 0
|
|
||||||
for i, bit in enumerate(bits, start=1):
|
|
||||||
n += int(bit) / (1 << i)
|
|
||||||
return n
|
|
||||||
|
|
||||||
|
def ae_decompress(
|
||||||
|
|
||||||
def make_cumulative(probs):
|
):
|
||||||
cumulative = []
|
pass
|
||||||
|
|
||||||
total = 0
|
def auto_encoder_decompress(
|
||||||
|
|
||||||
for prob in probs:
|
):
|
||||||
low = total
|
pass
|
||||||
high = total + prob
|
|
||||||
cumulative.append((low, high))
|
|
||||||
total = high
|
|
||||||
return cumulative
|
|
||||||
|
|
||||||
|
|
||||||
def decompress(
|
def decompress(
|
||||||
device,
|
device,
|
||||||
model_path: str,
|
model_path: str,
|
||||||
|
model_name: str,
|
||||||
input_file: str,
|
input_file: str,
|
||||||
output_file: str | None = None
|
output_file: str | None = None,
|
||||||
|
context_length: int = 128
|
||||||
):
|
):
|
||||||
context_length = 128
|
|
||||||
|
|
||||||
print("Reading in the data")
|
print("Reading in the data")
|
||||||
with open(input_file, "r") as f:
|
with open(input_file, "r") as f:
|
||||||
length = int(f.readline())
|
length = int(f.readline())
|
||||||
|
|
@ -144,30 +166,10 @@ def decompress(
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
print("Decompressing")
|
match model_name:
|
||||||
context = deque([0] * context_length, maxlen=context_length)
|
case "cnn":
|
||||||
output = bytearray()
|
ae_decompress()
|
||||||
|
case "autoencoder":
|
||||||
x = bits_to_number(bytes_data)
|
auto_encoder_decompress()
|
||||||
|
case _:
|
||||||
for _ in range(length):
|
raise ValueError(f"Unknown model type: {model_name}")
|
||||||
probs = model(context)
|
|
||||||
cumulative = make_cumulative(probs)
|
|
||||||
|
|
||||||
for symbol, (low, high) in enumerate(cumulative):
|
|
||||||
if low <= x < high:
|
|
||||||
break
|
|
||||||
|
|
||||||
output.append(symbol)
|
|
||||||
context.append(chr(symbol))
|
|
||||||
|
|
||||||
interval_low, interval_high = cumulative[symbol]
|
|
||||||
interval_width = interval_high - interval_low
|
|
||||||
x = (x - interval_low) / interval_width
|
|
||||||
|
|
||||||
if output_file is not None:
|
|
||||||
with open(output_file, "wb") as f:
|
|
||||||
f.write(output)
|
|
||||||
return
|
|
||||||
|
|
||||||
print(output.decode('utf-8', errors='replace'))
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from src.dataset_loaders import dataset_called
|
from src.dataset_loaders import dataset_called
|
||||||
from src.models import model_called
|
from src.models import model_called, Model
|
||||||
from src.trainers import OptunaTrainer, Trainer, FullTrainer
|
from src.trainers import OptunaTrainer, Trainer, FullTrainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -27,10 +27,10 @@ def train(
|
||||||
|
|
||||||
if model_name:
|
if model_name:
|
||||||
print(f"Creating model: {model_name}")
|
print(f"Creating model: {model_name}")
|
||||||
model = model_called[model_name]
|
model: Model | type[Model] = model_called[model_name]
|
||||||
else:
|
else:
|
||||||
print("Loading model from disk")
|
print("Loading model from disk")
|
||||||
model = torch.load(model_path, weights_only=False)
|
model: Model | type[Model] = torch.load(model_path, weights_only=False)
|
||||||
|
|
||||||
dataset_common_args = {
|
dataset_common_args = {
|
||||||
'root': data_root,
|
'root': data_root,
|
||||||
|
|
|
||||||
|
|
@ -12,15 +12,15 @@ class FullTrainer(Trainer):
|
||||||
|
|
||||||
def execute(
|
def execute(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: nn.Module | type[nn.Module],
|
||||||
context_length: int,
|
context_length: int,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
n_epochs: int | None,
|
n_epochs: int | None,
|
||||||
device: str
|
device: str
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
if model is None:
|
if not isinstance(model, Model):
|
||||||
raise ValueError("Model must be provided: run optuna optimizations first")
|
raise ValueError("Model must be an instance for full training")
|
||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
|
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from .train import train
|
from .train import train
|
||||||
from .trainer import Trainer
|
from .trainer import Trainer
|
||||||
from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder
|
from ..models import Model, CNNPredictor, AutoEncoder
|
||||||
|
|
||||||
|
|
||||||
def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
|
def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
|
||||||
|
|
@ -15,19 +15,6 @@ def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int =
|
||||||
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
|
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
|
||||||
vocab_size=256,
|
vocab_size=256,
|
||||||
)
|
)
|
||||||
if model_cls is ByteTransformer:
|
|
||||||
nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
|
|
||||||
# d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
|
|
||||||
return ByteTransformer(
|
|
||||||
d_model=context_length,
|
|
||||||
nhead=nhead,
|
|
||||||
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
|
|
||||||
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
|
|
||||||
dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
|
|
||||||
dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
|
|
||||||
activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
|
|
||||||
layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
|
|
||||||
)
|
|
||||||
if model_cls is AutoEncoder:
|
if model_cls is AutoEncoder:
|
||||||
channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
|
channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
|
||||||
latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
|
latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
|
||||||
|
|
@ -60,7 +47,7 @@ class OptunaTrainer(Trainer):
|
||||||
|
|
||||||
def execute(
|
def execute(
|
||||||
self,
|
self,
|
||||||
model: type[Model],
|
model: Model | type[Model],
|
||||||
context_length,
|
context_length,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
|
|
|
||||||
|
|
@ -4,22 +4,7 @@ import torch
|
||||||
from torch.utils.data.dataloader import DataLoader
|
from torch.utils.data.dataloader import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from ..models import ByteTransformer, Model, AutoEncoder
|
from ..models import Model, AutoEncoder
|
||||||
|
|
||||||
|
|
||||||
def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
|
|
||||||
if isinstance(model, ByteTransformer):
|
|
||||||
tgt_in = torch.cat([
|
|
||||||
torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
|
|
||||||
x[:, :-1]
|
|
||||||
], dim=1)
|
|
||||||
logits = model(x, tgt_in)
|
|
||||||
|
|
||||||
# only consider the last time step of the model where the full context
|
|
||||||
# is available
|
|
||||||
return logits[:, -1, :]
|
|
||||||
return model(x)
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
model: Model,
|
model: Model,
|
||||||
|
|
@ -53,9 +38,10 @@ def train(
|
||||||
y = y.long().to(device)
|
y = y.long().to(device)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
pred = _forward(model, x, device)
|
pred = model(x)
|
||||||
|
|
||||||
if isinstance(model, AutoEncoder):
|
if isinstance(model, AutoEncoder):
|
||||||
|
pred = pred.squeeze(1)
|
||||||
loss = loss_fn(pred, x.float() / 255.0)
|
loss = loss_fn(pred, x.float() / 255.0)
|
||||||
else:
|
else:
|
||||||
loss = loss_fn(pred, y)
|
loss = loss_fn(pred, y)
|
||||||
|
|
@ -74,9 +60,11 @@ def train(
|
||||||
x = x.long().to(device)
|
x = x.long().to(device)
|
||||||
y = y.long().to(device)
|
y = y.long().to(device)
|
||||||
|
|
||||||
pred = _forward(model, x, device)
|
pred = model(x)
|
||||||
|
|
||||||
if isinstance(model, AutoEncoder):
|
if isinstance(model, AutoEncoder):
|
||||||
|
# compare the reconstructed vector with the original
|
||||||
|
pred = pred.squeeze(1)
|
||||||
loss = loss_fn(pred, x.float() / 255.0)
|
loss = loss_fn(pred, x.float() / 255.0)
|
||||||
else:
|
else:
|
||||||
loss = loss_fn(pred, y)
|
loss = loss_fn(pred, y)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from src.models import Model
|
||||||
|
|
||||||
|
|
||||||
class Trainer(ABC):
|
class Trainer(ABC):
|
||||||
"""Abstract class for trainers."""
|
"""Abstract class for trainers."""
|
||||||
|
|
@ -10,7 +12,7 @@ class Trainer(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def execute(
|
def execute(
|
||||||
self,
|
self,
|
||||||
model: nn.Module | type[nn.Module] | None,
|
model: Model | type[Model],
|
||||||
context_length: int,
|
context_length: int,
|
||||||
train_loader: DataLoader,
|
train_loader: DataLoader,
|
||||||
validation_loader: DataLoader,
|
validation_loader: DataLoader,
|
||||||
|
|
|
||||||
Reference in a new issue