diff --git a/main.py b/main.py index 7b1aab2..1601d4e 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,5 @@ from src.args import parse_arguments -from src.process import compress +from src.process import compress, decompress from src.train import train from src.utils import determine_device @@ -38,12 +38,21 @@ def main(): case 'compress': compress(device=device, + model_name=args.model, model_path=args.model_load_path, input_file=args.input_file, output_file=args.output_file, context_length=args.context ) - + case 'decompress': + decompress( + device=device, + model_name=args.model, + model_path=args.model_load_path, + input_file=args.input_file, + output_file=args.output_file, + context_length=args.context + ) case _: raise NotImplementedError(f"Mode {args.mode} is not implemented yet") diff --git a/src/models/__init__.py b/src/models/__init__.py index dfdc5de..61ac3eb 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,10 +1,8 @@ from .Model import Model from .autoencoder import AutoEncoder from .cnn import CNNPredictor -from .transformer import ByteTransformer model_called: dict[str, type[Model]] = { 'cnn': CNNPredictor, - 'transformer': ByteTransformer, 'autoencoder': AutoEncoder } diff --git a/src/models/autoencoder/autoencoder.py b/src/models/autoencoder/autoencoder.py index 770e6f1..553e3eb 100644 --- a/src/models/autoencoder/autoencoder.py +++ b/src/models/autoencoder/autoencoder.py @@ -47,9 +47,15 @@ class AutoEncoder(Model): self.decoder = Decoder(latent_dim, channel_count, input_size) def encode(self, x: torch.Tensor) -> torch.Tensor: + """ + x: torch.Tensor of floats + """ return self.encoder(x) def decode(self, x: torch.Tensor) -> torch.Tensor: + """ + x: torch.Tensor of floats + """ return self.decoder(x) def forward(self, x: torch.LongTensor) -> torch.Tensor: diff --git a/src/models/transformer/__init__.py b/src/models/transformer/__init__.py deleted file mode 100644 index 9817800..0000000 --- a/src/models/transformer/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .transformer import ByteTransformer \ No newline at end of file diff --git a/src/models/transformer/transformer.py b/src/models/transformer/transformer.py deleted file mode 100644 index f85e60d..0000000 --- a/src/models/transformer/transformer.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Optional - -import torch.nn as nn -from torch import Tensor, arange - -from src.models import Model - - -class LearnedPositionalEncoding(Model): - def __init__(self, max_len, d_model): - super().__init__() - self.pos_emb = nn.Embedding(max_len, d_model) - - def forward(self, x): - # x: [seq, batch, d_model] - seq_len = x.size(0) - positions = arange(seq_len, device=x.device).unsqueeze(1) # [seq, 1] - return x + self.pos_emb(positions) # broadcast over batch - -class ByteTransformer(nn.Module): - def __init__( - self, - d_model=512, - nhead=8, - num_encoder_layers=6, - num_decoder_layers=6, - dim_feedforward=2048, - dropout=0.1, - activation="relu", - layer_norm_eps=1e-05, - max_len=128 - ): - super().__init__() - self.src_embedding = nn.Embedding(256, d_model) - self.tgt_embedding = nn.Embedding(256, d_model) - - self.src_pos = LearnedPositionalEncoding(max_len, d_model) - self.tgt_pos = LearnedPositionalEncoding(max_len, d_model) - - self.transformer = nn.Transformer( - d_model=d_model, - nhead=nhead, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dim_feedforward=dim_feedforward, - dropout=dropout, - activation=activation, - layer_norm_eps=layer_norm_eps, - batch_first=False, - norm_first=False, - device=None, - dtype=None, - ) - - self.output_proj = nn.Linear(d_model, 256) - - self.loss_function = nn.CrossEntropyLoss() - - def forward( - self, - src: Tensor, - tgt: Tensor, - ) -> Tensor: - src_embeds = self.src_embedding(src) - tgt_embeds = self.tgt_embedding(tgt) - - src_pos = self.src_pos(src_embeds) - tgt_pos = self.tgt_pos(tgt_embeds) - - return self.output_proj(self.transformer(src_pos, tgt_pos)) diff --git a/src/process.py b/src/process.py index 31de886..9598e67 100644 --- a/src/process.py +++ b/src/process.py @@ -1,11 +1,14 @@ import contextlib +import math from collections import deque from decimal import Decimal import numpy as np import torch +import torch.nn as nn from tqdm import tqdm +from src.models import AutoEncoder from src.utils import reference_ae @@ -22,9 +25,74 @@ def probs_to_freqs(probs, total_freq=8192): return freqs +def ae_compress( + output_file: str, + context_length: int, + device: str, + model: nn.Module, + byte_data: bytes, + tensor: torch.Tensor + +): + # Init AE + print("Initializing AE") + + with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout: + enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout) + + context = deque([0] * context_length, maxlen=context_length) + + # Compress + for byte in tqdm(tensor.tolist(), desc="Compressing"): + context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) + + with torch.inference_mode(): + logits = model(context_tensor) + # normalize + mean = logits.mean(dim=-1, keepdim=True) + std = logits.std(dim=-1, keepdim=True) + logits = (logits - mean) / (std + 1e-6) + print(f"logits: {logits}") + probabilities = torch.softmax(logits[0], dim=-1) + print(f"probabilities: {probabilities}") + probabilities = probabilities.detach() + probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities)) + + # write byte to output file + enc.write(probability_table, byte) + + context.append(byte) + +def chunk_data(x: bytes, context_length = 128) -> torch.Tensor: + tensor_data = torch.tensor(list(x), dtype=torch.long) + shape = tensor_data.size(0) + row_count = math.ceil(shape / context_length) + pad_count = row_count * context_length - shape + tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0) + return tensor_data.view(row_count, context_length) + +def auto_encoder_compress( + data: bytes, + model: AutoEncoder, + output_file: str, + context_length: int = 128, + device: str = "cuda" +): + # convert data to chunks of context length tensors + # send the data to device + tensor = chunk_data(data, context_length).to(device) + + # compress + output = model.encode(tensor) + print(output.shape) + + + + def compress( device, model_path: str, + model_name: str, context_length: int = 128, input_file: str | None = None, output_file: str | None = None @@ -48,88 +116,42 @@ def compress( model.to(device) model.eval() - # Init AE - print("Initializing AE") - - with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout: - enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout) - - context = deque([0] * context_length, maxlen=context_length) - - stage_min, stage_max = Decimal(0), Decimal(1) - stage = None - - # Compress - for byte in tqdm(tensor.tolist(), desc="Compressing"): - context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device) - - with torch.inference_mode(): - logits = model(context_tensor) - #normalize - mean = logits.mean(dim=-1, keepdim=True) - std = logits.std(dim=-1, keepdim=True) - logits = (logits - mean) / (std + 1e-6) - print(f"logits: {logits}") - probabilities = torch.softmax(logits[0], dim=-1) - print(f"probabilities: {probabilities}") - probabilities = probabilities.detach() - - eps = 1e-8 - # np.add(probabilities, eps) - # frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))} - probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities)) - # probability_table = AE.get_probability_table(frequency_table) - - enc.write(probability_table, byte) - - context.append(byte) - - # print("Getting encoded value") - # interval_min, interval_max, _ = AE.get_encoded_value(stage) - # print("Encoding in binary") - # binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max) - - # Pack - # val = int(binary_code, 2) if len(binary_code) else 0 - # out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big") - - # if output_file: - # print(f"Writing to {output_file}") - # with open(output_file, "w") as file: - # file.write(f"{len(byte_data)}\n") - # file.write(binary_code) # todo: temporary, decoding depends on binary string - # else: - # print(out_bytes) + match model_name: + case "cnn": + ae_compress( + output_file, + context_length, + device, + model, + byte_data, + tensor + ) + case "autoencoder": + auto_encoder_compress() + case _: + raise ValueError(f"Unknown model type: {model_name}") -def bits_to_number(bits: str) -> float: - n = 0 - for i, bit in enumerate(bits, start=1): - n += int(bit) / (1 << i) - return n +def ae_decompress( -def make_cumulative(probs): - cumulative = [] +): + pass - total = 0 +def auto_encoder_decompress( - for prob in probs: - low = total - high = total + prob - cumulative.append((low, high)) - total = high - return cumulative +): + pass def decompress( device, model_path: str, + model_name: str, input_file: str, - output_file: str | None = None + output_file: str | None = None, + context_length: int = 128 ): - context_length = 128 - print("Reading in the data") with open(input_file, "r") as f: length = int(f.readline()) @@ -144,30 +166,10 @@ def decompress( model.to(device) model.eval() - print("Decompressing") - context = deque([0] * context_length, maxlen=context_length) - output = bytearray() - - x = bits_to_number(bytes_data) - - for _ in range(length): - probs = model(context) - cumulative = make_cumulative(probs) - - for symbol, (low, high) in enumerate(cumulative): - if low <= x < high: - break - - output.append(symbol) - context.append(chr(symbol)) - - interval_low, interval_high = cumulative[symbol] - interval_width = interval_high - interval_low - x = (x - interval_low) / interval_width - - if output_file is not None: - with open(output_file, "wb") as f: - f.write(output) - return - - print(output.decode('utf-8', errors='replace')) + match model_name: + case "cnn": + ae_decompress() + case "autoencoder": + auto_encoder_decompress() + case _: + raise ValueError(f"Unknown model type: {model_name}") diff --git a/src/train.py b/src/train.py index 917ef2a..cd66745 100644 --- a/src/train.py +++ b/src/train.py @@ -4,7 +4,7 @@ import torch from torch.utils.data import DataLoader from src.dataset_loaders import dataset_called -from src.models import model_called +from src.models import model_called, Model from src.trainers import OptunaTrainer, Trainer, FullTrainer @@ -27,10 +27,10 @@ def train( if model_name: print(f"Creating model: {model_name}") - model = model_called[model_name] + model: Model | type[Model] = model_called[model_name] else: print("Loading model from disk") - model = torch.load(model_path, weights_only=False) + model: Model | type[Model] = torch.load(model_path, weights_only=False) dataset_common_args = { 'root': data_root, diff --git a/src/trainers/FullTrainer.py b/src/trainers/FullTrainer.py index a94ac5c..205d177 100644 --- a/src/trainers/FullTrainer.py +++ b/src/trainers/FullTrainer.py @@ -12,15 +12,15 @@ class FullTrainer(Trainer): def execute( self, - model: Model, + model: nn.Module | type[nn.Module], context_length: int, train_loader: DataLoader, validation_loader: DataLoader, n_epochs: int | None, device: str ) -> nn.Module: - if model is None: - raise ValueError("Model must be provided: run optuna optimizations first") + if not isinstance(model, Model): + raise ValueError("Model must be an instance for full training") model.to(device) train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs, diff --git a/src/trainers/OptunaTrainer.py b/src/trainers/OptunaTrainer.py index fa39ea1..81e33da 100644 --- a/src/trainers/OptunaTrainer.py +++ b/src/trainers/OptunaTrainer.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader from .train import train from .trainer import Trainer -from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder +from ..models import Model, CNNPredictor, AutoEncoder def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128): @@ -15,19 +15,6 @@ def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True), vocab_size=256, ) - if model_cls is ByteTransformer: - nhead = trial.suggest_categorical("nhead", [2, 4, 8]) # Only powers of 2 - # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead) - return ByteTransformer( - d_model=context_length, - nhead=nhead, - num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True), - num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True), - dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True), - dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True), - activation=trial.suggest_categorical("activation", ["relu", "gelu"]), - layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True), - ) if model_cls is AutoEncoder: channel_count = trial.suggest_int("channel_count", 1, 8, log=True) latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True) @@ -60,7 +47,7 @@ class OptunaTrainer(Trainer): def execute( self, - model: type[Model], + model: Model | type[Model], context_length, train_loader: DataLoader, validation_loader: DataLoader, diff --git a/src/trainers/train.py b/src/trainers/train.py index ac05d27..6edb053 100644 --- a/src/trainers/train.py +++ b/src/trainers/train.py @@ -4,22 +4,7 @@ import torch from torch.utils.data.dataloader import DataLoader from tqdm import tqdm -from ..models import ByteTransformer, Model, AutoEncoder - - -def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor: - if isinstance(model, ByteTransformer): - tgt_in = torch.cat([ - torch.zeros(x.shape[0], 1, device=device, dtype=torch.long), - x[:, :-1] - ], dim=1) - logits = model(x, tgt_in) - - # only consider the last time step of the model where the full context - # is available - return logits[:, -1, :] - return model(x) - +from ..models import Model, AutoEncoder def train( model: Model, @@ -53,9 +38,10 @@ def train( y = y.long().to(device) optimizer.zero_grad() - pred = _forward(model, x, device) + pred = model(x) if isinstance(model, AutoEncoder): + pred = pred.squeeze(1) loss = loss_fn(pred, x.float() / 255.0) else: loss = loss_fn(pred, y) @@ -74,9 +60,11 @@ def train( x = x.long().to(device) y = y.long().to(device) - pred = _forward(model, x, device) + pred = model(x) if isinstance(model, AutoEncoder): + # compare the reconstructed vector with the original + pred = pred.squeeze(1) loss = loss_fn(pred, x.float() / 255.0) else: loss = loss_fn(pred, y) diff --git a/src/trainers/trainer.py b/src/trainers/trainer.py index 2a9b99e..1af7ad4 100644 --- a/src/trainers/trainer.py +++ b/src/trainers/trainer.py @@ -3,6 +3,8 @@ from abc import ABC, abstractmethod import torch.nn as nn from torch.utils.data import DataLoader +from src.models import Model + class Trainer(ABC): """Abstract class for trainers.""" @@ -10,7 +12,7 @@ class Trainer(ABC): @abstractmethod def execute( self, - model: nn.Module | type[nn.Module] | None, + model: Model | type[Model], context_length: int, train_loader: DataLoader, validation_loader: DataLoader,