feat: new CNN, start of creating graphs
parent 17e0b52600
commit 5bb254d6c2
7 changed files with 151 additions and 49 deletions
.gitignore (vendored): 1 addition

@@ -3,3 +3,4 @@ __pycache__
 data/
 saved_models/
 results/
+output/
config/download_datasets.sh: 0 changes (Normal file → Executable file)
graphs.ipynb: 98 additions (new file)
File diff suppressed because one or more lines are too long
@@ -50,12 +50,16 @@ class AutoEncoder(Model):
         """
         x: torch.Tensor of floats
         """
+        if len(x.shape) == 2:
+            x = x.unsqueeze(1)
         return self.encoder(x)
 
     def decode(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: torch.Tensor of floats
         """
+        if len(x.shape) == 2:
+            x = x.unsqueeze(1)
         return self.decoder(x)
 
     def forward(self, x: torch.LongTensor) -> torch.Tensor:
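The added len(x.shape) == 2 guard lets encode() and decode() accept both (B, L) and (B, 1, L) inputs. A minimal sketch of the shape handling, assuming a Conv1d-based stand-in for the repo's actual self.encoder (which is not shown in this diff):

import torch
import torch.nn as nn

# Stand-in for the AutoEncoder's encoder: Conv1d expects (B, C, L).
encoder = nn.Conv1d(in_channels=1, out_channels=8, kernel_size=3, padding=1)

x = torch.rand(4, 128)      # (B, L) batch without a channel axis
if len(x.shape) == 2:
    x = x.unsqueeze(1)      # (B, 1, L), same guard as in the diff above
print(encoder(x).shape)     # torch.Size([4, 8, 128])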
@@ -1,57 +1,51 @@
+import torch
 import torch.nn as nn
 
 from src.models import Model
 
 
 class CNNPredictor(Model):
     def __init__(
         self,
-        vocab_size=256,
-        embed_dim=64,
-        hidden_dim=128,
+        vocab_size: int = 256,
+        hidden_dim: int = 128,
     ):
-        super().__init__(nn.CrossEntropyLoss())
+        super().__init__(loss_function=nn.CrossEntropyLoss())
 
-        # 1. Embedding: maps bytes (0–255) → vectors
-        self.embed = nn.Embedding(vocab_size, embed_dim)
-
-        # 2. Convolutional feature extractor
-        self.conv_layers = nn.Sequential(
-            nn.Conv1d(embed_dim, hidden_dim, kernel_size=5, padding=2),
+        # Treat bytes as a 1D signal with 1 channel
+        self.feature_extractor = nn.Sequential(
+            nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1),
             nn.BatchNorm1d(hidden_dim),
             nn.ReLU(),
-            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
-            nn.BatchNorm1d(hidden_dim),
+            nn.MaxPool1d(kernel_size=2),
+            nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1),
+            nn.BatchNorm1d(2 * hidden_dim),
             nn.ReLU(),
-            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
-            nn.BatchNorm1d(hidden_dim),
+            nn.MaxPool1d(kernel_size=2),
+            nn.Conv1d(2 * hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1),
+            nn.BatchNorm1d(2 * hidden_dim),
             nn.ReLU(),
         )
 
-        # 3. Global pooling to collapse sequence length
-        self.pool = nn.AdaptiveAvgPool1d(1)  # → (B, hidden_channels, 1)
+        # Collapse sequence dimension → fixed-size representation
+        self.global_pool = nn.AdaptiveAvgPool1d(1)  # (B, hidden_dim, 1)
 
-        # 4. Final classifier
-        self.fc = nn.Linear(hidden_dim, vocab_size)  # → (B, 256)
+        # Classification head
+        self.classifier = nn.Linear(2 * hidden_dim, vocab_size)
 
-    def forward(self, x):
+    def forward(self, x: torch.LongTensor) -> torch.Tensor:
         """
-        x: LongTensor of shape (B, 128), values 0-255
+        x: (B, L) LongTensor with values in [0, 255]
+        returns: logits (B, 256)
         """
-        # embed: (B, 128, embed_dim)
-        x = self.embed(x)
-
-        # conv1d expects (B, C_in, L) → swap dims
-        x = x.transpose(1, 2)  # (B, embed_dim, 128)
-
-        # apply CNN
-        x = self.conv_layers(x)  # (B, hidden_channels, 128)
-
-        # global average pooling over sequence
-        x = self.pool(x).squeeze(-1)  # (B, hidden_channels)
-
-        # final classifier
-        logits = self.fc(x)  # (B, 256)
-        return logits
+        # Convert bytes to float signal
+        x = x.float() / 255.0  # (B, L)
+        x = x.unsqueeze(1)  # (B, 1, L)
+
+        features = self.feature_extractor(x)  # (B, 2 * hidden_dim, L')
+        pooled = self.global_pool(features)  # (B, 2 * hidden_dim, 1)
+        pooled = pooled.squeeze(-1)  # (B, 2 * hidden_dim)
+
+        logits = self.classifier(pooled)  # (B, 256)
+        return logits
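For reference, a standalone sketch of the new architecture's shape flow, using bare nn modules instead of the repo's Model base class (an assumption, since that class is not part of this diff). With hidden_dim=128 and a 128-byte context, the two MaxPool1d stages halve the sequence twice:

import torch
import torch.nn as nn

hidden_dim, vocab_size = 128, 256
feature_extractor = nn.Sequential(
    nn.Conv1d(1, hidden_dim, kernel_size=3, padding=1),
    nn.BatchNorm1d(hidden_dim),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2),
    nn.Conv1d(hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1),
    nn.BatchNorm1d(2 * hidden_dim),
    nn.ReLU(),
    nn.MaxPool1d(kernel_size=2),
    nn.Conv1d(2 * hidden_dim, 2 * hidden_dim, kernel_size=3, padding=1),
    nn.BatchNorm1d(2 * hidden_dim),
    nn.ReLU(),
)
global_pool = nn.AdaptiveAvgPool1d(1)
classifier = nn.Linear(2 * hidden_dim, vocab_size)

x = torch.randint(0, 256, (4, 128))           # (B, L) byte values
x = (x.float() / 255.0).unsqueeze(1)          # (B, 1, L)
features = feature_extractor(x)               # (4, 256, 32): L halved by each MaxPool1d
pooled = global_pool(features).squeeze(-1)    # (4, 256)
logits = classifier(pooled)                   # (4, 256) next-byte logits
print(features.shape, pooled.shape, logits.shape)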
@@ -36,7 +36,6 @@ def ae_compress(
 ):
     # Init AE
     print("Initializing AE")
-
     with contextlib.closing(reference_ae.BitOutputStream(open(output_file, "wb"))) as bitout:
         enc = reference_ae.ArithmeticEncoder(len(byte_data), bitout)
 
@@ -48,11 +47,6 @@ def ae_compress(
 
        with torch.inference_mode():
            logits = model(context_tensor)
-            # normalize
-            mean = logits.mean(dim=-1, keepdim=True)
-            std = logits.std(dim=-1, keepdim=True)
-            logits = (logits - mean) / (std + 1e-6)
-            print(f"logits: {logits}")
            probabilities = torch.softmax(logits[0], dim=-1)
            print(f"probabilities: {probabilities}")
            probabilities = probabilities.detach()
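A side note on the removed normalization: softmax is invariant to subtracting the mean, while dividing by the standard deviation rescales the logits like a temperature, so only the second step actually changed the coded distribution. A quick check:

import torch

logits = torch.tensor([[2.0, 0.5, -1.0]])

p_raw = torch.softmax(logits, dim=-1)
p_shifted = torch.softmax(logits - logits.mean(dim=-1, keepdim=True), dim=-1)
p_scaled = torch.softmax(logits / (logits.std(dim=-1, keepdim=True) + 1e-6), dim=-1)

print(torch.allclose(p_raw, p_shifted))  # True: shifting logits does not change softmax
print(torch.allclose(p_raw, p_scaled))   # False: rescaling acts like a temperature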
@@ -69,7 +63,7 @@ def chunk_data(x: bytes, context_length = 128) -> torch.Tensor:
     row_count = math.ceil(shape / context_length)
     pad_count = row_count * context_length - shape
     tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
-    return tensor_data.view(row_count, context_length)
+    return tensor_data.view(row_count, context_length).float() / 255.0
 
 def auto_encoder_compress(
     data: bytes,
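A small worked example of the chunking arithmetic; the byte-to-tensor conversion is assumed here (the repo's chunk_data builds tensor_data before the lines shown above). 300 input bytes pad up to 3 rows of 128:

import math
import torch
import torch.nn as nn

data = bytes(range(256)) + bytes(44)     # 300 bytes
context_length = 128

tensor_data = torch.tensor(list(data), dtype=torch.float32)
row_count = math.ceil(len(data) / context_length)     # ceil(300 / 128) = 3
pad_count = row_count * context_length - len(data)    # 3 * 128 - 300 = 84 zero bytes
tensor_data = nn.functional.pad(tensor_data, (0, pad_count), value=0)
chunks = tensor_data.view(row_count, context_length).float() / 255.0
print(chunks.shape)   # torch.Size([3, 128]), values scaled into [0, 1]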
@@ -82,9 +76,15 @@ def auto_encoder_compress(
     # send the data to device
     tensor = chunk_data(data, context_length).to(device)
 
+    print(f"input shape of compress: {len(data)} bytes")
+
     # compress
     output = model.encode(tensor)
-    print(output.shape)
+    print(f"output shape of compress: {4 * output.shape[0] * output.shape[1]} bytes")
+
+    # write output to file
+    print(f"saving to file {output_file}...")
+    torch.save(output.detach(), output_file)
 
 
 
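The new size print assumes a 2-D float32 latent, hence the factor of 4 bytes per element. A dtype-agnostic equivalent, shown on a stand-in tensor:

import torch

output = torch.randn(3, 32)                        # stand-in for model.encode(tensor)
print(4 * output.shape[0] * output.shape[1])       # 384, hard-codes float32 and 2-D shape
print(output.element_size() * output.nelement())   # 384, works for any dtype and shape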
@@ -127,7 +127,13 @@ def compress(
                 tensor
             )
         case "autoencoder":
-            auto_encoder_compress()
+            auto_encoder_compress(
+                byte_data,
+                model,
+                output_file,
+                context_length,
+                device
+            )
         case _:
             raise ValueError(f"Unknown model type: {model_name}")
 
@@ -11,8 +11,7 @@ from ..models import Model, CNNPredictor, AutoEncoder
 def create_model(trial: tr.Trial, model_cls: type[Model], context_length: int = 128):
     if model_cls is CNNPredictor:
         return CNNPredictor(
-            hidden_dim=trial.suggest_int("hidden_dim", 64, 512, log=True),
-            embed_dim=trial.suggest_int("embed_dim", 64, 512, log=True),
+            hidden_dim=trial.suggest_int("hidden_dim", context_length, 512, log=True),
             vocab_size=256,
         )
     if model_cls is AutoEncoder:
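The tuned search space now ties the lower bound of hidden_dim to context_length and drops embed_dim. A minimal Optuna sketch of that space with a dummy objective (assumed here; the repo's real objective builds and trains the model):

import optuna

CONTEXT_LENGTH = 128

def objective(trial: optuna.trial.Trial) -> float:
    hidden_dim = trial.suggest_int("hidden_dim", CONTEXT_LENGTH, 512, log=True)
    # Dummy score so the sketch runs end to end; the real objective would
    # instantiate CNNPredictor(hidden_dim=hidden_dim, vocab_size=256) and train it.
    return abs(hidden_dim - 256)

study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=10)
print(study.best_params)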