feat: add autoencoder, update trainers, and clean up the compress/decompress process to support it

RobinMeersman 2025-12-14 14:37:04 +01:00
parent 0ab495165f
commit 17e0b52600
11 changed files with 132 additions and 211 deletions

main.py (13 changes)
View file

@@ -1,5 +1,5 @@
from src.args import parse_arguments
from src.process import compress
from src.process import compress, decompress
from src.train import train
from src.utils import determine_device
@@ -38,12 +21,21 @@ def main():
case 'compress':
compress(device=device,
model_name=args.model,
model_path=args.model_load_path,
input_file=args.input_file,
output_file=args.output_file,
context_length=args.context
)
case 'decompress':
decompress(
device=device,
model_name=args.model,
model_path=args.model_load_path,
input_file=args.input_file,
output_file=args.output_file,
context_length=args.context
)
case _:
raise NotImplementedError(f"Mode {args.mode} is not implemented yet")
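For reference, a minimal sketch of driving the two modes directly from Python using the compress/decompress signatures shown further down in process.py; the file paths and model name are placeholders, and determine_device() is assumed to take no arguments.

    from src.process import compress, decompress
    from src.utils import determine_device

    device = determine_device()

    # Round-trip a file through the CNN-based coder (hypothetical paths/checkpoint).
    compress(device=device, model_name="cnn", model_path="models/cnn.pt",
             input_file="data/input.bin", output_file="data/input.bin.ac",
             context_length=128)
    decompress(device=device, model_name="cnn", model_path="models/cnn.pt",
               input_file="data/input.bin.ac", output_file="data/output.bin",
               context_length=128)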

View file

@@ -1,10 +1,8 @@
from .Model import Model
from .autoencoder import AutoEncoder
from .cnn import CNNPredictor
from .transformer import ByteTransformer
model_called: dict[str, type[Model]] = {
'cnn': CNNPredictor,
'transformer': ByteTransformer,
'autoencoder': AutoEncoder
}
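The registry maps a CLI model name to its class; train.py (below) either takes the class from this dict or loads a pickled instance from disk, which is why its annotations widen to Model | type[Model]. A minimal sketch of the lookup:

    from src.models import model_called, Model

    entry = model_called["autoencoder"]      # the AutoEncoder class itself, not an instance
    assert isinstance(entry, type) and issubclass(entry, Model)
    # OptunaTrainer builds instances from the class per trial; FullTrainer expects an instance.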

View file

@@ -47,9 +47,15 @@ class AutoEncoder(Model):
self.decoder = Decoder(latent_dim, channel_count, input_size)
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""
x: torch.Tensor of floats
"""
return self.encoder(x)
def decode(self, x: torch.Tensor) -> torch.Tensor:
"""
x: torch.Tensor of floats
"""
return self.decoder(x)
def forward(self, x: torch.LongTensor) -> torch.Tensor:
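A minimal round-trip sketch through the new hooks, assuming `model` is a trained AutoEncoder and that, per the docstrings, both hooks take float tensors; exact tensor shapes depend on the Encoder/Decoder internals, which are not shown in this hunk.

    import torch

    def roundtrip(model, chunk: torch.Tensor) -> torch.Tensor:
        """chunk: byte values scaled to floats in [0, 1]."""
        with torch.inference_mode():
            latent = model.encode(chunk)   # compact representation that would be stored
            recon = model.decode(latent)   # approximate reconstruction of the chunk
        return recon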

View file

@@ -1 +0,0 @@
from .transformer import ByteTransformer

View file

@@ -1,70 +0,0 @@
from typing import Optional
import torch.nn as nn
from torch import Tensor, arange
from src.models import Model
class LearnedPositionalEncoding(Model):
def __init__(self, max_len, d_model):
super().__init__()
self.pos_emb = nn.Embedding(max_len, d_model)
def forward(self, x):
# x: [seq, batch, d_model]
seq_len = x.size(0)
positions = arange(seq_len, device=x.device).unsqueeze(1) # [seq, 1]
return x + self.pos_emb(positions) # broadcast over batch
class ByteTransformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
layer_norm_eps=1e-05,
max_len=128
):
super().__init__()
self.src_embedding = nn.Embedding(256, d_model)
self.tgt_embedding = nn.Embedding(256, d_model)
self.src_pos = LearnedPositionalEncoding(max_len, d_model)
self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
self.transformer = nn.Transformer(
d_model=d_model,
nhead=nhead,
num_encoder_layers=num_encoder_layers,
num_decoder_layers=num_decoder_layers,
dim_feedforward=dim_feedforward,
dropout=dropout,
activation=activation,
layer_norm_eps=layer_norm_eps,
batch_first=False,
norm_first=False,
device=None,
dtype=None,
)
self.output_proj = nn.Linear(d_model, 256)
self.loss_function = nn.CrossEntropyLoss()
def forward(
self,
src: Tensor,
tgt: Tensor,
) -> Tensor:
src_embeds = self.src_embedding(src)
tgt_embeds = self.tgt_embedding(tgt)
src_pos = self.src_pos(src_embeds)
tgt_pos = self.tgt_pos(tgt_embeds)
return self.output_proj(self.transformer(src_pos, tgt_pos))

View file

@@ -1,11 +1,14 @@
import contextlib
import math
from collections import deque
from decimal import Decimal
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from src.models import AutoEncoder
from src.utils import reference_ae
@@ -22,9 +25,74 @@ def probs_to_freqs(probs, total_freq=8192):
return freqs
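Only the tail of probs_to_freqs is visible in this hunk. A sketch of what such a conversion typically does (scale probabilities to integer counts around total_freq with a floor of 1, so the arithmetic coder never sees a zero-frequency symbol); this is an illustration, not the committed implementation.

    def probs_to_freqs_sketch(probs, total_freq=8192):
        # probs: 256 probabilities summing to ~1 (tensor or list)
        return [max(1, int(float(p) * total_freq)) for p in probs]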
def ae_compress(
output_file: str,
context_length: int,
device: str,
model: nn.Module,
byte_data: bytes,
tensor: torch.Tensor
):
# Init AE
print("Initializing AE")
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
context = deque([0] * context_length, maxlen=context_length)
# Compress
for byte in tqdm(tensor.tolist(), desc="Compressing"):
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
with torch.inference_mode():
logits = model(context_tensor)
# normalize
mean = logits.mean(dim=-1, keepdim=True)
std = logits.std(dim=-1, keepdim=True)
logits = (logits - mean) / (std + 1e-6)
print(f"logits: {logits}")
probabilities = torch.softmax(logits[0], dim=-1)
print(f"probabilities: {probabilities}")
probabilities = probabilities.detach()
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
# write byte to output file
enc.write(probability_table, byte)
context.append(byte)
def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
tensor_data = torch.tensor(list(x), dtype=torch.long)
shape = tensor_data.size(0)
row_count = math.ceil(shape / context_length)
pad_count = row_count * context_length - shape
tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
return tensor_data.view(row_count, context_length)
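A worked example of what chunk_data produces: the byte string is zero-padded to a multiple of context_length and reshaped into one chunk per row.

    chunks = chunk_data(b"hello", context_length=4)
    # 5 bytes -> padded to 8 -> 2 rows of 4:
    # tensor([[104, 101, 108, 108],
    #         [111,   0,   0,   0]])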
def auto_encoder_compress(
data: bytes,
model: AutoEncoder,
output_file: str,
context_length: int = 128,
device: str = "cuda"
):
# convert data to chunks of context length tensors
# send the data to device
tensor = chunk_data(data, context_length).to(device)
# compress
output = model.encode(tensor)
print(output.shape)
def compress(
device,
model_path: str,
model_name: str,
context_length: int = 128,
input_file: str | None = None,
output_file: str | None = None
@@ -48,88 +116,42 @@ def compress(
model.to(device)
model.eval()
# Init AE
print("Initializing AE")
with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
context = deque([0] * context_length, maxlen=context_length)
stage_min, stage_max = Decimal(0), Decimal(1)
stage = None
# Compress
for byte in tqdm(tensor.tolist(), desc="Compressing"):
context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
with torch.inference_mode():
logits = model(context_tensor)
#normalize
mean = logits.mean(dim=-1, keepdim=True)
std = logits.std(dim=-1, keepdim=True)
logits = (logits - mean) / (std + 1e-6)
print(f"logits: {logits}")
probabilities = torch.softmax(logits[0], dim=-1)
print(f"probabilities: {probabilities}")
probabilities = probabilities.detach()
eps = 1e-8
# np.add(probabilities, eps)
# frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
# probability_table = AE.get_probability_table(frequency_table)
enc.write(probability_table, byte)
context.append(byte)
# print("Getting encoded value")
# interval_min, interval_max, _ = AE.get_encoded_value(stage)
# print("Encoding in binary")
# binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
# Pack
# val = int(binary_code, 2) if len(binary_code) else 0
# out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
# if output_file:
# print(f"Writing to {output_file}")
# with open(output_file, "w") as file:
# file.write(f"{len(byte_data)}\n")
# file.write(binary_code) # todo: temporary, decoding depends on binary string
# else:
# print(out_bytes)
match model_name:
case "cnn":
ae_compress(
output_file,
context_length,
device,
model,
byte_data,
tensor
)
case "autoencoder":
auto_encoder_compress()
case _:
raise ValueError(f"Unknown model type: {model_name}")
def bits_to_number(bits: str) -> float:
n = 0
for i, bit in enumerate(bits, start=1):
n += int(bit) / (1 << i)
return n
def make_cumulative(probs):
cumulative = []
total = 0
for prob in probs:
low = total
high = total + prob
cumulative.append((low, high))
total = high
return cumulative
def ae_decompress(
):
pass
def auto_encoder_decompress(
):
pass
def decompress(
device,
model_path: str,
model_name: str,
input_file: str,
output_file: str | None = None
output_file: str | None = None,
context_length: int = 128
):
context_length = 128
print("Reading in the data")
with open(input_file, "r") as f:
length = int(f.readline())
@@ -144,30 +166,10 @@ def decompress(
model.to(device)
model.eval()
print("Decompressing")
context = deque([0] * context_length, maxlen=context_length)
output = bytearray()
x = bits_to_number(bytes_data)
for _ in range(length):
probs = model(context)
cumulative = make_cumulative(probs)
for symbol, (low, high) in enumerate(cumulative):
if low <= x < high:
break
output.append(symbol)
context.append(chr(symbol))
interval_low, interval_high = cumulative[symbol]
interval_width = interval_high - interval_low
x = (x - interval_low) / interval_width
if output_file is not None:
with open(output_file, "wb") as f:
f.write(output)
return
print(output.decode('utf-8', errors='replace'))
match model_name:
case "cnn":
ae_decompress()
case "autoencoder":
auto_encoder_decompress()
case _:
raise ValueError(f"Unknown model type: {model_name}")
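ae_decompress and auto_encoder_decompress are still stubs in this commit. A possible shape for ae_decompress, mirroring ae_compress above; it assumes reference_ae also exposes decoder-side counterparts (a BitInputStream and an ArithmeticDecoder with a read(freq_table) method), which is an assumption, not something shown in this diff.

    def ae_decompress_sketch(input_file, output_file, length, context_length, device, model):
        output = bytearray()
        with contextlib.closing(reference_ae.BitInputStream(open(input_file, "rb"))) as bitin:
            dec = reference_ae.ArithmeticDecoder(length, bitin)
            context = deque([0] * context_length, maxlen=context_length)
            for _ in tqdm(range(length), desc="Decompressing"):
                context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
                with torch.inference_mode():
                    logits = model(context_tensor)
                # must mirror the encoder's normalization exactly, or the streams diverge
                mean = logits.mean(dim=-1, keepdim=True)
                std = logits.std(dim=-1, keepdim=True)
                logits = (logits - mean) / (std + 1e-6)
                probabilities = torch.softmax(logits[0], dim=-1)
                freq_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
                symbol = dec.read(freq_table)   # same frequency table as on the encode side
                output.append(symbol)
                context.append(symbol)
        with open(output_file, "wb") as f:
            f.write(output)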

View file

@@ -4,7 +4,7 @@ import torch
from torch.utils.data import DataLoader
from src.dataset_loaders import dataset_called
from src.models import model_called
from src.models import model_called, Model
from src.trainers import OptunaTrainer, Trainer, FullTrainer
@@ -27,10 +27,10 @@ def train(
if model_name:
print(f"Creating model: {model_name}")
model = model_called[model_name]
model: Model | type[Model] = model_called[model_name]
else:
print("Loading model from disk")
model = torch.load(model_path, weights_only=False)
model: Model | type[Model] = torch.load(model_path, weights_only=False)
dataset_common_args = {
'root': data_root,

View file

@@ -12,15 +12,15 @@ class FullTrainer(Trainer):
def execute(
self,
model: Model,
model: nn.Module | type[nn.Module],
context_length: int,
train_loader: DataLoader,
validation_loader: DataLoader,
n_epochs: int | None,
device: str
) -> nn.Module:
if model is None:
raise ValueError("Model must be provided: run optuna optimizations first")
if not isinstance(model, Model):
raise ValueError("Model must be an instance for full training")
model.to(device)
train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,

View file

@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
from .train import train
from .trainer import Trainer
from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder
from ..models import Model, CNNPredictor, AutoEncoder
def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
@@ -15,19 +15,6 @@ def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int =
embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
vocab_size=256,
)
if model_cls is ByteTransformer:
nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2
# d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
return ByteTransformer(
d_model=context_length,
nhead=nhead,
num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
)
if model_cls is AutoEncoder:
channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
@@ -60,7 +47,7 @@ class OptunaTrainer(Trainer):
def execute(
self,
model: type[Model],
model: Model | type[Model],
context_length,
train_loader: DataLoader,
validation_loader: DataLoader,
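For context, a minimal sketch of how a search space like this is usually driven; the objective body is illustrative and not the OptunaTrainer implementation.

    import optuna

    def objective(trial):
        model = create_model(trial, AutoEncoder, context_length=128)
        # ... train briefly and evaluate on the validation loader ...
        return 0.0  # placeholder for the validation loss

    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=20)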

View file

@@ -4,22 +4,7 @@ import torch
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from ..models import ByteTransformer, Model, AutoEncoder
def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
if isinstance(model, ByteTransformer):
tgt_in = torch.cat([
torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
x[:, :-1]
], dim=1)
logits = model(x, tgt_in)
# only consider the last time step of the model where the full context
# is available
return logits[:, -1, :]
return model(x)
from ..models import Model, AutoEncoder
def train(
model: Model,
@@ -53,9 +38,10 @@ def train(
y = y.long().to(device)
optimizer.zero_grad()
pred = _forward(model, x, device)
pred = model(x)
if isinstance(model, AutoEncoder):
pred = pred.squeeze(1)
loss = loss_fn(pred, x.float() / 255.0)
else:
loss = loss_fn(pred, y)
@@ -74,9 +60,11 @@ def train(
x = x.long().to(device)
y = y.long().to(device)
pred = _forward(model, x, device)
pred = model(x)
if isinstance(model, AutoEncoder):
# compare the reconstructed vector with the original
pred = pred.squeeze(1)
loss = loss_fn(pred, x.float() / 255.0)
else:
loss = loss_fn(pred, y)
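The loss now branches on the model type: predictive models are scored with cross-entropy against the next byte y, while the AutoEncoder is scored on reconstructing its own input scaled to [0, 1]. A small standalone sketch of the two regimes (the concrete loss functions are assumptions here, since model.loss_function is not shown for the AutoEncoder).

    import torch
    import torch.nn as nn

    x = torch.randint(0, 256, (32, 128))   # batch of byte contexts
    y = torch.randint(0, 256, (32,))       # true next byte per sample

    # predictive model: logits over 256 byte values vs. the next byte
    logits = torch.randn(32, 256)
    ce = nn.CrossEntropyLoss()(logits, y)

    # autoencoder: reconstruction compared against the normalized input
    recon = torch.rand(32, 128)            # stands in for model(x).squeeze(1)
    mse = nn.MSELoss()(recon, x.float() / 255.0)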

View file

@@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
import torch.nn as nn
from torch.utils.data import DataLoader
from src.models import Model
class Trainer(ABC):
"""Abstract class for trainers."""
@@ -10,7 +12,7 @@ class Trainer(ABC):
@abstractmethod
def execute(
self,
model: nn.Module | type[nn.Module] | None,
model: Model | type[Model],
context_length: int,
train_loader: DataLoader,
validation_loader: DataLoader,