feat: autoencoder + updated trainers + cleaned up process to allow using autoencoder

This commit is contained in:
RobinMeersman 2025-12-14 14:37:04 +01:00
parent 0ab495165f
commit 17e0b52600
11 changed files with 132 additions and 211 deletions

13
main.py
View file

@ -1,5 +1,5 @@
from src.args import parse_arguments from src.args import parse_arguments
from src.process import compress from src.process import compress, decompress
from src.train import train from src.train import train
from src.utils import determine_device from src.utils import determine_device
@ -38,12 +38,21 @@ def main():
case 'compress': case 'compress':
compress(device=device, compress(device=device,
model_name=args.model,
model_path=args.model_load_path,
input_file=args.input_file,
output_file=args.output_file,
context_length=args.context
)
case 'decompress':
decompress(
device=device,
model_name=args.model,
model_path=args.model_load_path, model_path=args.model_load_path,
input_file=args.input_file, input_file=args.input_file,
output_file=args.output_file, output_file=args.output_file,
context_length=args.context context_length=args.context
) )
case _: case _:
raise NotImplementedError(f"Mode {args.mode} is not implemented yet") raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

View file

@ -1,10 +1,8 @@
from .Model import Model from .Model import Model
from .autoencoder import AutoEncoder from .autoencoder import AutoEncoder
from .cnn import CNNPredictor from .cnn import CNNPredictor
from .transformer import ByteTransformer
model_called: dict[str, type[Model]] = { model_called: dict[str, type[Model]] = {
'cnn': CNNPredictor, 'cnn': CNNPredictor,
'transformer': ByteTransformer,
'autoencoder': AutoEncoder 'autoencoder': AutoEncoder
} }

View file

@ -47,9 +47,15 @@ class AutoEncoder(Model):
self.decoder = Decoder(latent_dim, channel_count, input_size) self.decoder = Decoder(latent_dim, channel_count, input_size)
def encode(self, x: torch.Tensor) -> torch.Tensor: def encode(self, x: torch.Tensor) -> torch.Tensor:
"""
x: torch.Tensor of floats
"""
return self.encoder(x) return self.encoder(x)
def decode(self, x: torch.Tensor) -> torch.Tensor: def decode(self, x: torch.Tensor) -> torch.Tensor:
"""
x: torch.Tensor of floats
"""
return self.decoder(x) return self.decoder(x)
def forward(self, x: torch.LongTensor) -> torch.Tensor: def forward(self, x: torch.LongTensor) -> torch.Tensor:

View file

@ -1 +0,0 @@
from .transformer import ByteTransformer

View file

@ -1,70 +0,0 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor, arange
from src.models import Model
class LearnedPositionalEncoding(Model):
def __init__(self, max_len, d_model):
super().__init__()
self.pos_emb = nn.Embedding(max_len, d_model)
def forward(self, x):
# x: [seq, batch, d_model]
seq_len = x.size(0)
positions = arange(seq_len, device=x.device).unsqueeze(1) # [seq, 1]
return x + self.pos_emb(positions) # broadcast over batch
class ByteTransformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
layer_norm_eps=1e-05,
max_len=128
):
super().__init__()
self.src_embedding = nn.Embedding(256, d_model)
self.tgt_embedding = nn.Embedding(256, d_model)
self.src_pos = LearnedPositionalEncoding(max_len, d_model)
self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=False,
norm_first=False,
device=None,
dtype=None,
)
self.output_proj = nn.Linear(d_model, 256)
self.loss_function = nn.CrossEntropyLoss()
def forward(
self,
src: Tensor,
tgt: Tensor,
) -> Tensor:
src_embeds = self.src_embedding(src)
tgt_embeds = self.tgt_embedding(tgt)
src_pos = self.src_pos(src_embeds)
tgt_pos = self.tgt_pos(tgt_embeds)
return self.output_proj(self.transformer(src_pos, tgt_pos))

View file

@ -1,11 +1,14 @@
import contextlib import contextlib
import math
from collections import deque from collections import deque
from decimal import Decimal from decimal import Decimal
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
from src.models import AutoEncoder
from src.utils import reference_ae from src.utils import reference_ae
@ -22,9 +25,74 @@ def probs_to_freqs(probs, total_freq=8192):
return freqs return freqs
def ae_compress(
output_file: str,
context_length: int,
device: str,
model: nn.Module,
byte_data: bytes,
tensor: torch.Tensor
):
# Init AE
print("Initializing AE")
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
context = deque([0] * context_length, maxlen=context_length)
# Compress
for byte in tqdm(tensor.tolist(), desc="Compressing"):
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
with torch.inference_mode():
logits = model(context_tensor)
# normalize
mean = logits.mean(dim=-1, keepdim=True)
std = logits.std(dim=-1, keepdim=True)
logits = (logits - mean) / (std + 1e-6)
print(f"logits: {logits}")
probabilities = torch.softmax(logits[0], dim=-1)
print(f"probabilities: {probabilities}")
probabilities = probabilities.detach()
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
# write byte to output file
enc.write(probability_table, byte)
context.append(byte)
def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
tensor_data = torch.tensor(list(x), dtype=torch.long)
shape = tensor_data.size(0)
row_count = math.ceil(shape / context_length)
pad_count = row_count * context_length - shape
tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
return tensor_data.view(row_count, context_length)
def auto_encoder_compress(
data: bytes,
model: AutoEncoder,
output_file: str,
context_length: int = 128,
device: str = "cuda"
):
# convert data to chunks of context length tensors
# send the data to device
tensor = chunk_data(data, context_length).to(device)
# compress
output = model.encode(tensor)
print(output.shape)
def compress( def compress(
device, device,
model_path: str, model_path: str,
model_name: str,
context_length: int = 128, context_length: int = 128,
input_file: str | None = None, input_file: str | None = None,
output_file: str | None = None output_file: str | None = None
@ -48,88 +116,42 @@ def compress(
model.to(device) model.to(device)
model.eval() model.eval()
# Init AE match model_name:
print("Initializing AE") case "cnn":
ae_compress(
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout: output_file,
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout) context_length,
device,
context = deque([0] * context_length, maxlen=context_length) model,
byte_data,
stage_min, stage_max = Decimal(0), Decimal(1) tensor
stage = None )
case "autoencoder":
# Compress auto_encoder_compress()
for byte in tqdm(tensor.tolist(), desc="Compressing"): case _:
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) raise ValueError(f"Unknown model type: {model_name}")
with torch.inference_mode():
logits = model(context_tensor)
#normalize
mean = logits.mean(dim=-1, keepdim=True)
std = logits.std(dim=-1, keepdim=True)
logits = (logits - mean) / (std + 1e-6)
print(f"logits: {logits}")
probabilities = torch.softmax(logits[0], dim=-1)
print(f"probabilities: {probabilities}")
probabilities = probabilities.detach()
eps = 1e-8
# np.add(probabilities, eps)
# frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
# probability_table = AE.get_probability_table(frequency_table)
enc.write(probability_table, byte)
context.append(byte)
# print("Getting encoded value")
# interval_min, interval_max, _ = AE.get_encoded_value(stage)
# print("Encoding in binary")
# binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
# Pack
# val = int(binary_code, 2) if len(binary_code) else 0
# out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
# if output_file:
# print(f"Writing to {output_file}")
# with open(output_file, "w") as file:
# file.write(f"{len(byte_data)}\n")
# file.write(binary_code) # todo: temporary, decoding depends on binary string
# else:
# print(out_bytes)
def bits_to_number(bits: str) -> float:
n = 0
for i, bit in enumerate(bits, start=1):
n += int(bit) / (1 << i)
return n
def ae_decompress(
def make_cumulative(probs): ):
cumulative = [] pass
total = 0 def auto_encoder_decompress(
for prob in probs: ):
low = total pass
high = total + prob
cumulative.append((low, high))
total = high
return cumulative
def decompress( def decompress(
device, device,
model_path: str, model_path: str,
model_name: str,
input_file: str, input_file: str,
output_file: str | None = None output_file: str | None = None,
context_length: int = 128
): ):
context_length = 128
print("Reading in the data") print("Reading in the data")
with open(input_file, "r") as f: with open(input_file, "r") as f:
length = int(f.readline()) length = int(f.readline())
@ -144,30 +166,10 @@ def decompress(
model.to(device) model.to(device)
model.eval() model.eval()
print("Decompressing") match model_name:
context = deque([0] * context_length, maxlen=context_length) case "cnn":
output = bytearray() ae_decompress()
case "autoencoder":
x = bits_to_number(bytes_data) auto_encoder_decompress()
case _:
for _ in range(length): raise ValueError(f"Unknown model type: {model_name}")
probs = model(context)
cumulative = make_cumulative(probs)
for symbol, (low, high) in enumerate(cumulative):
if low <= x < high:
break
output.append(symbol)
context.append(chr(symbol))
interval_low, interval_high = cumulative[symbol]
interval_width = interval_high - interval_low
x = (x - interval_low) / interval_width
if output_file is not None:
with open(output_file, "wb") as f:
f.write(output)
return
print(output.decode('utf-8', errors='replace'))

View file

@ -4,7 +4,7 @@ import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called from src.dataset_loaders import dataset_called
from src.models import model_called from src.models import model_called, Model
from src.trainers import OptunaTrainer, Trainer, FullTrainer from src.trainers import OptunaTrainer, Trainer, FullTrainer
@ -27,10 +27,10 @@ def train(
if model_name: if model_name:
print(f"Creating model: {model_name}") print(f"Creating model: {model_name}")
model = model_called[model_name] model: Model | type[Model] = model_called[model_name]
else: else:
print("Loading model from disk") print("Loading model from disk")
model = torch.load(model_path, weights_only=False) model: Model | type[Model] = torch.load(model_path, weights_only=False)
dataset_common_args = { dataset_common_args = {
'root': data_root, 'root': data_root,

View file

@ -12,15 +12,15 @@ class FullTrainer(Trainer):
def execute( def execute(
self, self,
model: Model, model: nn.Module | type[nn.Module],
context_length: int, context_length: int,
train_loader: DataLoader, train_loader: DataLoader,
validation_loader: DataLoader, validation_loader: DataLoader,
n_epochs: int | None, n_epochs: int | None,
device: str device: str
) -> nn.Module: ) -> nn.Module:
if model is None: if not isinstance(model, Model):
raise ValueError("Model must be provided: run optuna optimizations first") raise ValueError("Model must be an instance for full training")
model.to(device) model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,

View file

@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
from .train import train from .train import train
from .trainer import Trainer from .trainer import Trainer
from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder from ..models import Model, CNNPredictor, AutoEncoder
def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128): def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
@ -15,19 +15,6 @@ def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int =
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True), embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
vocab_size=256, vocab_size=256,
) )
if model_cls is ByteTransformer:
nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
# d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
return ByteTransformer(
d_model=context_length,
nhead=nhead,
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
)
if model_cls is AutoEncoder: if model_cls is AutoEncoder:
channel_count = trial.suggest_int("channel_count", 1, 8, log=True) channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True) latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
@ -60,7 +47,7 @@ class OptunaTrainer(Trainer):
def execute( def execute(
self, self,
model: type[Model], model: Model | type[Model],
context_length, context_length,
train_loader: DataLoader, train_loader: DataLoader,
validation_loader: DataLoader, validation_loader: DataLoader,

View file

@ -4,22 +4,7 @@ import torch
from torch.utils.data.dataloader import DataLoader from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm from tqdm import tqdm
from ..models import ByteTransformer, Model, AutoEncoder from ..models import Model, AutoEncoder
def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
if isinstance(model, ByteTransformer):
tgt_in = torch.cat([
torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
x[:, :-1]
], dim=1)
logits = model(x, tgt_in)
# only consider the last time step of the model where the full context
# is available
return logits[:, -1, :]
return model(x)
def train( def train(
model: Model, model: Model,
@ -53,9 +38,10 @@ def train(
y = y.long().to(device) y = y.long().to(device)
optimizer.zero_grad() optimizer.zero_grad()
pred = _forward(model, x, device) pred = model(x)
if isinstance(model, AutoEncoder): if isinstance(model, AutoEncoder):
pred = pred.squeeze(1)
loss = loss_fn(pred, x.float() / 255.0) loss = loss_fn(pred, x.float() / 255.0)
else: else:
loss = loss_fn(pred, y) loss = loss_fn(pred, y)
@ -74,9 +60,11 @@ def train(
x = x.long().to(device) x = x.long().to(device)
y = y.long().to(device) y = y.long().to(device)
pred = _forward(model, x, device) pred = model(x)
if isinstance(model, AutoEncoder): if isinstance(model, AutoEncoder):
# compare the reconstructed vector with the original
pred = pred.squeeze(1)
loss = loss_fn(pred, x.float() / 255.0) loss = loss_fn(pred, x.float() / 255.0)
else: else:
loss = loss_fn(pred, y) loss = loss_fn(pred, y)

View file

@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
import torch.nn as nn import torch.nn as nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from src.models import Model
class Trainer(ABC): class Trainer(ABC):
"""Abstract class for trainers.""" """Abstract class for trainers."""
@ -10,7 +12,7 @@ class Trainer(ABC):
@abstractmethod @abstractmethod
def execute( def execute(
self, self,
model: nn.Module | type[nn.Module] | None, model: Model | type[Model],
context_length: int, context_length: int,
train_loader: DataLoader, train_loader: DataLoader,
validation_loader: DataLoader, validation_loader: DataLoader,