"""Benchmark the learned compressors in ``src.process`` on several datasets.

For every (model checkpoint, input file) pair this script compresses the file,
decompresses it again, and appends one CSV row with the original and compressed
sizes, the reconstruction MSE and the time taken by each step to
``./results/compress/compression_results_auto_small.csv``.
"""

import os
import time
from contextlib import contextmanager

import torch
import torch.nn.functional as F

import src.process as p

_RESULTS_PATH = "./results/compress/compression_results_auto_small.csv"
_TMP_PATH = "./output/tmp.pt"
_CSV_HEADER = (
    "model_type,model_name,context_length,dataset_type,input_file_name,original_file_size,compressed_file_size,mse_loss,compression_time,decompression_time\n"
)


@contextmanager
def timer():
    """Time the body of a ``with`` block in nanoseconds.

    Yields a zero-argument callable returning the nanoseconds elapsed since the
    block was entered. The first call freezes the value; if it is not called
    inside the block, it is frozen when the block exits, so calling it later
    returns the duration of the whole block.
    """
    # perf_counter_ns is monotonic and high-resolution; time_ns follows the
    # adjustable wall clock and is not meant for measuring durations.
    start = time.perf_counter_ns()
    elapsed = None

    def get_elapsed():
        nonlocal elapsed
        if elapsed is None:
            elapsed = time.perf_counter_ns() - start
        return elapsed

    try:
        yield get_elapsed
    finally:
        get_elapsed()


def _read_bytes_as_float(path) -> torch.Tensor:
    """Read the file at *path* and return its bytes as a 1-D float32 CPU tensor."""
    with open(path, "rb") as file:
        data = file.read()
    if not data:
        # torch.frombuffer rejects zero-length buffers.
        return torch.empty(0, dtype=torch.float32)
    # frombuffer avoids building a Python list with one int object per byte,
    # which is very slow and memory-hungry for large inputs. bytearray gives
    # it a writable buffer, so no non-writable-buffer warning is emitted.
    return torch.frombuffer(bytearray(data), dtype=torch.uint8).float()


def compare_files(original, decompressed: str | torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between a file and its reconstruction.

    Args:
        original: Path of the original file; its bytes are compared as values
            in 0..255.
        decompressed: Either a path to the decompressed file, or a tensor of
            reconstructed byte values. A tensor is flattened, moved to the CPU
            and cast to float32 before comparison.

    Returns:
        A 0-dim float tensor: the MSE over the first ``len(original)`` values
        of the reconstruction (any extra trailing values are ignored).

    Raises:
        ValueError: if the reconstruction holds fewer values than the original.
    """
    original_values = _read_bytes_as_float(original)
    if isinstance(decompressed, (str, os.PathLike)):
        decompressed_values = _read_bytes_as_float(decompressed)
    else:
        decompressed_values = decompressed.flatten().cpu().float()

    n = original_values.shape[0]
    if decompressed_values.shape[0] < n:
        raise ValueError(
            f"reconstruction has {decompressed_values.shape[0]} values, "
            f"expected at least {n}"
        )
    return F.mse_loss(decompressed_values[:n], original_values)


def _synchronize(device: str) -> None:
    """Wait for queued CUDA work so timings include asynchronous GPU kernels."""
    if device.startswith("cuda"):
        torch.cuda.synchronize()


def _benchmark_one(device, model, context_length, model_name, file) -> str:
    """Compress and decompress one input file with one checkpoint; return a CSV row."""
    dataset_type = "genome" if "genome" in model else "enwik9"
    in_file = f"./data/compression_sets/{file}"
    model_path = f"./models/{model_name}/{model}"
    print(f"Running for model {model} and file {file}...")

    try:
        with timer() as t:
            compressed = p.compress(
                device=device,
                input_file=in_file,
                model_name=model_name,
                model_path=model_path,
                context_length=context_length,
                output_file=_TMP_PATH,
            )
            _synchronize(device)
            compression_time = t()

        with timer() as t:
            decompressed = p.decompress(
                device,
                model_name=model_name,
                model_path=model_path,
                context_length=context_length,
                input_file=_TMP_PATH,
            )
            _synchronize(device)
            decompression_time = t()

        mse_loss = compare_files(in_file, decompressed.flatten().cpu())
        og_file_len = os.path.getsize(in_file)
        if compressed is None:
            # No tensor returned: measure the output file instead.
            compressed_size = os.path.getsize(_TMP_PATH)
        else:
            # Counts 4 bytes per element, i.e. assumes the compressed
            # representation is stored as float32 -- TODO confirm.
            compressed_size = 4 * compressed.numel()
    finally:
        # Remove the intermediate file even if compression/decompression failed.
        if os.path.exists(_TMP_PATH):
            os.remove(_TMP_PATH)

    # model_name holds the architecture ("autoencoder"/"cnn") and fills the
    # "model_type" column; model is the checkpoint file name ("model_name").
    return (
        f"{model_name},{model},{context_length},{dataset_type},{file},"
        f"{og_file_len},{compressed_size},{mse_loss.item()},"
        f"{compression_time},{decompression_time}\n"
    )


def main():
    """Run every configured (checkpoint, file) pair and write the results CSV."""
    files_genome = ["genome.fna", "genome_large.fna", "genome_xlarge.fna"]
    files_genome_cnn = ["genome_small.fna", "genome_xsmall.fna", "genome_xxsmall.fna"]
    files_enwik9 = ["text.txt", "text_large.txt", "text_xlarge.txt"]
    files_enwik9_cnn = ["text_small.txt", "text_xsmall.txt", "text_xxsmall.txt"]

    # (checkpoint file, context length, architecture, input files)
    models = [
        ("auto-genome-full-256.pt", 256, "autoencoder", files_genome),
        ("auto-genome-full-128.pt", 128, "autoencoder", files_genome),
        ("cnn-genome-full-256.pt", 256, "cnn", files_genome_cnn),
        ("cnn-genome-full-128.pt", 128, "cnn", files_genome_cnn),
        ("auto-enwik9-full-256.pt", 256, "autoencoder", files_enwik9),
        ("auto-enwik9-full-128.pt", 128, "autoencoder", files_enwik9),
        ("cnn-enwik9-full-256.pt", 256, "cnn", files_enwik9_cnn),
        ("cnn-enwik9-full-128.pt", 128, "cnn", files_enwik9_cnn),
    ]

    device = "cuda" if torch.cuda.is_available() else "cpu"

    os.makedirs(os.path.dirname(_RESULTS_PATH), exist_ok=True)
    os.makedirs(os.path.dirname(_TMP_PATH), exist_ok=True)

    with open(_RESULTS_PATH, "w") as f:
        f.write(_CSV_HEADER)
        for model, context_length, model_name, files in models:
            for file in files:
                f.write(_benchmark_one(device, model, context_length, model_name, file))
                # Flush so rows from finished runs survive a later crash.
                f.flush()


if __name__ == "__main__":
    main()