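"""Benchmark trained compression models over a set of test files.

For each (model checkpoint, input file) pair, runs p.compress and
p.decompress, then records file sizes, byte-level match accuracy, and
timings to ./results/compress/compression_results.csv.
"""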
import os
import time
from contextlib import contextmanager

import torch

import src.process as p


@contextmanager
def timer():
    """Yield a callable that reports elapsed time in nanoseconds.

    The elapsed time is frozen on the first call, so calling the
    function after the block exits returns the block's duration.
    """
    start = time.time_ns()
    elapsed = None

    def get_elapsed():
        nonlocal elapsed
        if elapsed is None:
            elapsed = time.time_ns() - start
        return elapsed

    try:
        yield get_elapsed
    finally:
        # Freeze the elapsed time even if the body raised.
        get_elapsed()


def compare_files(original: str, decompressed: str | torch.Tensor) -> float:
    """Return the fraction of bytes in `original` that `decompressed` matches."""
    with open(original, "rb") as file:
        original = file.read()
    original = torch.tensor(list(original), dtype=torch.uint8).cpu()

    # A string argument is treated as a path to the decompressed file on disk.
    if isinstance(decompressed, str):
        with open(decompressed, "rb") as file:
            decompressed = file.read()
        decompressed = torch.tensor(list(decompressed), dtype=torch.uint8).cpu()

    # Count matching bytes over the length of the original file.
    count = torch.sum(original == decompressed[:original.shape[0]])
    accuracy = count / original.shape[0]
    return accuracy.item()  # plain float, so the CSV gets a number, not tensor(...)


if __name__ == "__main__":
    files_genome = [
        "genome.fna",
        "genome_large.fna",
        "genome_xlarge.fna",
    ]

    files_genome_cnn = [
        "genome_small.fna",
        "genome_xsmall.fna",
        "genome_xxsmall.fna",
    ]

    files_enwik9 = [
        # "text.txt",
        # "txt_large.txt",
        # "txt_xlarge.txt"
    ]

    files_enwik9_cnn = []

    # (checkpoint file, context length, model type, evaluation files)
    models = [
        ("auto-genome-full-256.pt", 256, "autoencoder", files_genome),
        ("auto-genome-full-128.pt", 128, "autoencoder", files_genome),
        ("cnn-genome-full-256.pt", 256, "cnn", files_genome_cnn),
        ("cnn-genome-full-128.pt", 128, "cnn", files_genome_cnn),
        ("auto-enwik9-full-256.pt", 256, "autoencoder", files_enwik9),
        ("auto-enwik9-full-128.pt", 128, "autoencoder", files_enwik9),
        ("cnn-enwik9-full-256.pt", 256, "cnn", files_enwik9_cnn),
        ("cnn-enwik9-full-128.pt", 128, "cnn", files_enwik9_cnn),
    ]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Make sure the output locations exist before writing to them.
    os.makedirs("./results/compress", exist_ok=True)
    os.makedirs("./output", exist_ok=True)

    with open("./results/compress/compression_results.csv", "w") as f:
        # write header
        f.write(
            "model_type,model_name,context_length,input_file_name,"
            "original_file_size,compressed_file_size,match_percentage,"
            "compression_time,decompression_time\n"
        )

        for model, context_length, model_name, files in models:
            for file in files:
                in_file = f"./data/compression_sets/{file}"
                model_path = f"./models/{model_name}/{model}"
                print(f"Running for model {model} and file {file}...")

                with timer() as t:
                    compressed = p.compress(
                        device=device,
                        input_file=in_file,
                        model_name=model_name,
                        model_path=model_path,
                        context_length=context_length,
                        output_file="./output/tmp.pt",
                    )
                compression_time = t()

                with timer() as t:
                    decompressed = p.decompress(
                        device,
                        model_name=model_name,
                        model_path=model_path,
                        context_length=context_length,
                        input_file="./output/tmp.pt",
                    )
                decompression_time = t()

                accuracy = compare_files(in_file, decompressed.flatten().cpu())

                og_file_len = os.path.getsize(in_file)
                if compressed is None:
                    compressed_size = os.path.getsize("./output/tmp.pt")
                else:
                    # float32 tensor: 4 bytes per element
                    compressed_size = 4 * compressed.shape[0] * compressed.shape[1]

                os.remove("./output/tmp.pt")

                # model_name holds the model type ("autoencoder"/"cnn");
                # model holds the checkpoint file name.
                f.write(
                    f"{model_name},{model},{context_length},{file},"
                    f"{og_file_len},{compressed_size},{accuracy},"
                    f"{compression_time},{decompression_time}\n"
                )