feat: autoencoder + updated trainers + cleaned up process to allow using autoencoder

parent 0ab495165f
commit 17e0b52600

11 changed files with 132 additions and 211 deletions
13 main.py
@@ -1,5 +1,5 @@
 from src.args import parse_arguments
-from src.process import compress
+from src.process import compress, decompress
 from src.train import train
 from src.utils import determine_device
 
@@ -38,12 +38,21 @@ def main():
 
         case 'compress':
             compress(device=device,
                      model_name=args.model,
                      model_path=args.model_load_path,
                      input_file=args.input_file,
                      output_file=args.output_file,
                      context_length=args.context
                      )
+        case 'decompress':
+            decompress(
+                device=device,
+                model_name=args.model,
+                model_path=args.model_load_path,
+                input_file=args.input_file,
+                output_file=args.output_file,
+                context_length=args.context
+            )
 
         case _:
             raise NotImplementedError(f"Mode {args.mode} is not implemented yet")

src/models/__init__.py
@@ -1,10 +1,8 @@
 from .Model import Model
+from .autoencoder import AutoEncoder
 from .cnn import CNNPredictor
 from .transformer import ByteTransformer
 
 model_called: dict[str, type[Model]] = {
     'cnn': CNNPredictor,
     'transformer': ByteTransformer,
+    'autoencoder': AutoEncoder
 }
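
For reference, model_called maps the CLI model name to a model class; a minimal lookup sketch (the constructor call is hypothetical, its arguments are not shown in this diff):

    from src.models import model_called

    model_cls = model_called['autoencoder']   # -> AutoEncoder
    model = model_cls()                       # hypothetical: real ctor args not shown here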

src/models/autoencoder.py
@@ -47,9 +47,15 @@ class AutoEncoder(Model):
         self.decoder = Decoder(latent_dim, channel_count, input_size)
 
+    def encode(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        x: torch.Tensor of floats
+        """
+        return self.encoder(x)
+
     def decode(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: torch.Tensor of floats
         """
         return self.decoder(x)
 
     def forward(self, x: torch.LongTensor) -> torch.Tensor:
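
A hypothetical round trip through the new encode/decode pair, given a trained AutoEncoder instance `model`, assuming input_size matches the 128-byte context and inputs are scaled to [0, 1] the same way the training loop below scales them:

    import torch

    x = torch.randint(0, 256, (1, 128))             # one context-length chunk of bytes
    z = model.encode(x.float() / 255.0)             # latent code
    x_hat = model.decode(z)                         # reconstruction in [0, 1]
    restored = (x_hat * 255).round().clamp(0, 255).to(torch.uint8)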
@@ -1 +0,0 @@
-from .transformer import ByteTransformer

@@ -1,70 +0,0 @@
-from typing import Optional
-
-import torch.nn as nn
-from torch import Tensor, arange
-
-from src.models import Model
-
-
-class LearnedPositionalEncoding(Model):
-    def __init__(self, max_len, d_model):
-        super().__init__()
-        self.pos_emb = nn.Embedding(max_len, d_model)
-
-    def forward(self, x):
-        # x: [seq, batch, d_model]
-        seq_len = x.size(0)
-        positions = arange(seq_len, device=x.device).unsqueeze(1)  # [seq, 1]
-        return x + self.pos_emb(positions)  # broadcast over batch
-
-
-class ByteTransformer(nn.Module):
-    def __init__(
-        self,
-        d_model=512,
-        nhead=8,
-        num_encoder_layers=6,
-        num_decoder_layers=6,
-        dim_feedforward=2048,
-        dropout=0.1,
-        activation="relu",
-        layer_norm_eps=1e-05,
-        max_len=128
-    ):
-        super().__init__()
-        self.src_embedding = nn.Embedding(256, d_model)
-        self.tgt_embedding = nn.Embedding(256, d_model)
-
-        self.src_pos = LearnedPositionalEncoding(max_len, d_model)
-        self.tgt_pos = LearnedPositionalEncoding(max_len, d_model)
-
-        self.transformer = nn.Transformer(
-            d_model=d_model,
-            nhead=nhead,
-            num_encoder_layers=num_encoder_layers,
-            num_decoder_layers=num_decoder_layers,
-            dim_feedforward=dim_feedforward,
-            dropout=dropout,
-            activation=activation,
-            layer_norm_eps=layer_norm_eps,
-            batch_first=False,
-            norm_first=False,
-            device=None,
-            dtype=None,
-        )
-
-        self.output_proj = nn.Linear(d_model, 256)
-
-        self.loss_function = nn.CrossEntropyLoss()
-
-    def forward(
-        self,
-        src: Tensor,
-        tgt: Tensor,
-    ) -> Tensor:
-        src_embeds = self.src_embedding(src)
-        tgt_embeds = self.tgt_embedding(tgt)
-
-        src_pos = self.src_pos(src_embeds)
-        tgt_pos = self.tgt_pos(tgt_embeds)
-
-        return self.output_proj(self.transformer(src_pos, tgt_pos))
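
The deleted positional encoding leaned on broadcasting under the transformer's batch_first=False layout; the shape trace, following the removed comments:

    # x:                       [seq, batch, d_model]
    # positions:               [seq, 1]
    # self.pos_emb(positions): [seq, 1, d_model]  -> broadcasts over the batch dim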

194 src/process.py
@@ -1,11 +1,14 @@
 import contextlib
+import math
 from collections import deque
 from decimal import Decimal
 
 import numpy as np
 import torch
+import torch.nn as nn
 from tqdm import tqdm
 
+from src.models import AutoEncoder
 from src.utils import reference_ae
@@ -22,9 +25,74 @@ def probs_to_freqs(probs, total_freq=8192):
     return freqs
 
 
+def ae_compress(
+        output_file: str,
+        context_length: int,
+        device: str,
+        model: nn.Module,
+        byte_data: bytes,
+        tensor: torch.Tensor
+):
+    # Init AE
+    print("Initializing AE")
+
+    with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
+        enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
+
+        context = deque([0] * context_length, maxlen=context_length)
+
+        # Compress
+        for byte in tqdm(tensor.tolist(), desc="Compressing"):
+            context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
+
+            with torch.inference_mode():
+                logits = model(context_tensor)
+                # normalize
+                mean = logits.mean(dim=-1, keepdim=True)
+                std = logits.std(dim=-1, keepdim=True)
+                logits = (logits - mean) / (std + 1e-6)
+            print(f"logits: {logits}")
+            probabilities = torch.softmax(logits[0], dim=-1)
+            print(f"probabilities: {probabilities}")
+            probabilities = probabilities.detach()
+            probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
+
+            # write byte to output file
+            enc.write(probability_table, byte)
+
+            context.append(byte)
+
+
+def chunk_data(x: bytes, context_length=128) -> torch.Tensor:
+    tensor_data = torch.tensor(list(x), dtype=torch.long)
+    shape = tensor_data.size(0)
+    row_count = math.ceil(shape / context_length)
+    pad_count = row_count * context_length - shape
+    tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
+    return tensor_data.view(row_count, context_length)
+
+
+def auto_encoder_compress(
+        data: bytes,
+        model: AutoEncoder,
+        output_file: str,
+        context_length: int = 128,
+        device: str = "cuda"
+):
+    # convert data to chunks of context length tensors
+    # send the data to device
+    tensor = chunk_data(data, context_length).to(device)
+
+    # compress
+    output = model.encode(tensor)
+    print(output.shape)
+
+
 def compress(
         device,
         model_path: str,
         model_name: str,
+        context_length: int = 128,
         input_file: str | None = None,
        output_file: str | None = None
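
Only the tail of probs_to_freqs appears in this hunk; a minimal sketch consistent with its signature and the SimpleFrequencyTable consumer (an assumption, not the committed code):

    def probs_to_freqs(probs, total_freq=8192):
        # Scale softmax probabilities to integer counts; clamp to >= 1 so the
        # arithmetic coder never assigns a byte zero frequency.
        freqs = (probs * total_freq).long().clamp(min=1)
        return freqs.tolist()

For chunk_data, 300 input bytes with context_length=128 give row_count=3 and pad_count=84, i.e. a (3, 128) tensor.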
@@ -48,88 +116,42 @@ def compress(
     model.to(device)
     model.eval()
 
-    # Init AE
-    print("Initializing AE")
-
-    with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
-        enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
-
-        context = deque([0] * context_length, maxlen=context_length)
-
-        stage_min, stage_max = Decimal(0), Decimal(1)
-        stage = None
-
-        # Compress
-        for byte in tqdm(tensor.tolist(), desc="Compressing"):
-            context_tensor = torch.tensor([list(context)], dtype=torch.long, device=device)
-
-            with torch.inference_mode():
-                logits = model(context_tensor)
-                # normalize
-                mean = logits.mean(dim=-1, keepdim=True)
-                std = logits.std(dim=-1, keepdim=True)
-                logits = (logits - mean) / (std + 1e-6)
-            print(f"logits: {logits}")
-            probabilities = torch.softmax(logits[0], dim=-1)
-            print(f"probabilities: {probabilities}")
-            probabilities = probabilities.detach()
-
-            eps = 1e-8
-            # np.add(probabilities, eps)
-            # frequency_table = {i: float(probabilities[i]) + eps for i in range(len(probabilities))}
-            probability_table = reference_ae.SimpleFrequencyTable(probs_to_freqs(probabilities))
-            # probability_table = AE.get_probability_table(frequency_table)
-
-            enc.write(probability_table, byte)
-
-            context.append(byte)
-
-    # print("Getting encoded value")
-    # interval_min, interval_max, _ = AE.get_encoded_value(stage)
-    # print("Encoding in binary")
-    # binary_code, _ = AE.custom_binary_encoding(interval_min, interval_max)
-
-    # Pack
-    # val = int(binary_code, 2) if len(binary_code) else 0
-    # out_bytes = val.to_bytes((len(binary_code) + 7) // 8, "big")
-
-    # if output_file:
-    #     print(f"Writing to {output_file}")
-    #     with open(output_file, "w") as file:
-    #         file.write(f"{len(byte_data)}\n")
-    #         file.write(binary_code)  # todo: temporary, decoding depends on binary string
-    # else:
-    #     print(out_bytes)
+    match model_name:
+        case "cnn":
+            ae_compress(
+                output_file,
+                context_length,
+                device,
+                model,
+                byte_data,
+                tensor
+            )
+        case "autoencoder":
+            auto_encoder_compress()
+        case _:
+            raise ValueError(f"Unknown model type: {model_name}")
 
 
 def bits_to_number(bits: str) -> float:
     n = 0
     for i, bit in enumerate(bits, start=1):
         n += int(bit) / (1 << i)
     return n
 
 
-def make_cumulative(probs):
-    cumulative = []
-    total = 0
-    for prob in probs:
-        low = total
-        high = total + prob
-        cumulative.append((low, high))
-        total = high
-    return cumulative
+def ae_decompress(
+):
+    pass
+
+
+def auto_encoder_decompress(
+):
+    pass
 
 
 def decompress(
     device,
     model_path: str,
     model_name: str,
     input_file: str,
-    output_file: str | None = None
+    output_file: str | None = None,
+    context_length: int = 128
 ):
-    context_length = 128
-
     print("Reading in the data")
     with open(input_file, "r") as f:
         length = int(f.readline())
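
bits_to_number reads a bit string as a binary fraction in [0, 1); for example:

    bits_to_number("101")   # 1/2 + 0/4 + 1/8 = 0.625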
@@ -144,30 +166,10 @@ def decompress(
     model.to(device)
     model.eval()
 
-    print("Decompressing")
-    context = deque([0] * context_length, maxlen=context_length)
-    output = bytearray()
-
-    x = bits_to_number(bytes_data)
-
-    for _ in range(length):
-        probs = model(context)
-        cumulative = make_cumulative(probs)
-
-        for symbol, (low, high) in enumerate(cumulative):
-            if low <= x < high:
-                break
-
-        output.append(symbol)
-        context.append(chr(symbol))
-
-        interval_low, interval_high = cumulative[symbol]
-        interval_width = interval_high - interval_low
-        x = (x - interval_low) / interval_width
-
-    if output_file is not None:
-        with open(output_file, "wb") as f:
-            f.write(output)
-        return
-
-    print(output.decode('utf-8', errors='replace'))
+    match model_name:
+        case "cnn":
+            ae_decompress()
+        case "autoencoder":
+            auto_encoder_decompress()
+        case _:
+            raise ValueError(f"Unknown model type: {model_name}")
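
The deleted loop was a floating-point arithmetic decoder: find the symbol whose cumulative interval contains x, emit it, then rescale x into that interval. One worked step, assuming two symbols with probabilities 0.25 and 0.75:

    x = 0.40
    cumulative = [(0.0, 0.25), (0.25, 1.0)]
    # 0.25 <= 0.40 < 1.0 selects symbol 1; renormalize:
    x = (0.40 - 0.25) / (1.0 - 0.25)   # -> 0.2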

src/train.py
@@ -4,7 +4,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from src.dataset_loaders import dataset_called
-from src.models import model_called
+from src.models import model_called, Model
 from src.trainers import OptunaTrainer, Trainer, FullTrainer

@@ -27,10 +27,10 @@ def train(
 
     if model_name:
         print(f"Creating model: {model_name}")
-        model = model_called[model_name]
+        model: Model | type[Model] = model_called[model_name]
     else:
         print("Loading model from disk")
-        model = torch.load(model_path, weights_only=False)
+        model: Model | type[Model] = torch.load(model_path, weights_only=False)
 
     dataset_common_args = {
         'root': data_root,

src/trainers/full_trainer.py
@@ -12,15 +12,15 @@ class FullTrainer(Trainer):
 
     def execute(
         self,
-        model: Model,
+        model: nn.Module | type[nn.Module],
         context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,
         n_epochs: int | None,
         device: str
     ) -> nn.Module:
         if model is None:
             raise ValueError("Model must be provided: run optuna optimizations first")
         if not isinstance(model, Model):
             raise ValueError("Model must be an instance for full training")
 
         model.to(device)
         train_loss, val_loss = train(model, train_loader, validation_loader, model.loss_function, n_epochs,

src/trainers/optuna_trainer.py
@@ -5,7 +5,7 @@ from torch.utils.data import DataLoader
 
 from .train import train
 from .trainer import Trainer
-from ..models import Model, CNNPredictor, ByteTransformer, AutoEncoder
+from ..models import Model, CNNPredictor, AutoEncoder
 
 
 def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):

@@ -15,19 +15,6 @@ def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
         embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
         vocab_size=256,
     )
-    if model_cls is ByteTransformer:
-        nhead = trial.suggest_categorical("nhead", [2, 4, 8])  # Only powers of 2
-        # d_model_dim = nhead * trial.suggest_int("d_model_mult", 64 // nhead, 512 // nhead)
-        return ByteTransformer(
-            d_model=context_length,
-            nhead=nhead,
-            num_encoder_layers=trial.suggest_int("num_encoder_layers", 2, 6, log=True),
-            num_decoder_layers=trial.suggest_int("num_decoder_layers", 2, 6, log=True),
-            dim_feedforward=trial.suggest_int("dim_feedforward", 64, 512, log=True),
-            dropout=trial.suggest_float("dropout", 0.01, 0.5, log=True),
-            activation=trial.suggest_categorical("activation", ["relu", "gelu"]),
-            layer_norm_eps=trial.suggest_float("layer_norm_eps", 1e-8, 1e-6, log=True),
-        )
     if model_cls is AutoEncoder:
         channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
         latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
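
The AutoEncoder branch is cut off at the hunk boundary; a hypothetical completion (the constructor call and argument order are assumptions, not part of this diff):

    if model_cls is AutoEncoder:
        channel_count = trial.suggest_int("channel_count", 1, 8, log=True)
        latent_dim = trial.suggest_int("latent_dim", 32, 64, log=True)
        # hypothetical: the actual return sits below the hunk edge
        return AutoEncoder(input_size=context_length,
                           channel_count=channel_count,
                           latent_dim=latent_dim)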

@@ -60,7 +47,7 @@ class OptunaTrainer(Trainer):
 
     def execute(
         self,
-        model: type[Model],
+        model: Model | type[Model],
         context_length,
         train_loader: DataLoader,
         validation_loader: DataLoader,

src/trainers/train.py
@@ -4,22 +4,7 @@ import torch
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
-from ..models import ByteTransformer, Model, AutoEncoder
-
-
-def _forward(model: Model, x: torch.Tensor, device: str) -> torch.Tensor:
-    if isinstance(model, ByteTransformer):
-        tgt_in = torch.cat([
-            torch.zeros(x.shape[0], 1, device=device, dtype=torch.long),
-            x[:, :-1]
-        ], dim=1)
-        logits = model(x, tgt_in)
-
-        # only consider the last time step of the model where the full context
-        # is available
-        return logits[:, -1, :]
-    return model(x)
+from ..models import Model, AutoEncoder
 
 
 def train(
     model: Model,

@@ -53,9 +38,10 @@ def train(
         y = y.long().to(device)
 
         optimizer.zero_grad()
-        pred = _forward(model, x, device)
+        pred = model(x)
 
+        if isinstance(model, AutoEncoder):
+            pred = pred.squeeze(1)
+            loss = loss_fn(pred, x.float() / 255.0)
+        else:
+            loss = loss_fn(pred, y)

@@ -74,9 +60,11 @@ def train(
         x = x.long().to(device)
         y = y.long().to(device)
 
-        pred = _forward(model, x, device)
+        pred = model(x)
 
+        if isinstance(model, AutoEncoder):
+            # compare the reconstructed vector with the original
+            pred = pred.squeeze(1)
+            loss = loss_fn(pred, x.float() / 255.0)
+        else:
+            loss = loss_fn(pred, y)
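
For the AutoEncoder branch the target is the input itself scaled to [0, 1], so loss_fn is presumably a reconstruction loss; a minimal sketch of that assumption:

    import torch
    import torch.nn as nn

    loss_fn = nn.MSELoss()                    # assumption: reconstruction loss
    x = torch.randint(0, 256, (32, 128))      # a batch of byte chunks
    pred = torch.rand(32, 1, 128)             # model output with a channel dim
    loss = loss_fn(pred.squeeze(1), x.float() / 255.0)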

src/trainers/trainer.py
@@ -3,6 +3,8 @@ from abc import ABC, abstractmethod
 import torch.nn as nn
 from torch.utils.data import DataLoader
 
+from src.models import Model
+
 
 class Trainer(ABC):
     """Abstract class for trainers."""

@@ -10,7 +12,7 @@ class Trainer(ABC):
     @abstractmethod
     def execute(
         self,
-        model: nn.Module | type[nn.Module] | None,
+        model: Model | type[Model],
         context_length: int,
         train_loader: DataLoader,
         validation_loader: DataLoader,