fix: accuracy replaced by MSE loss, updated graphs

This commit is contained in:
RobinMeersman 2025-12-16 18:12:10 +01:00
parent 5bf45e47a5
commit 9cd37f156a
15 changed files with 38 additions and 38 deletions

View file

@ -2,6 +2,7 @@ import os
from contextlib import contextmanager
import torch
import torch.nn.functional as F
import src.process as p
@ -26,17 +27,16 @@ def timer():
def compare_files(original, decompressed: str | torch.Tensor):
with open(original, "rb") as file:
original = file.read()
original = torch.tensor(list(original), dtype=torch.uint8).cpu()
original = torch.tensor(list(original), dtype=torch.uint8).cpu().float()
if type(decompressed) == "str":
with open(decompressed, "rb") as file:
decompressed = file.read()
decompressed = torch.tensor(list(decompressed), dtype=torch.uint8).cpu()
decompressed = torch.tensor(list(decompressed), dtype=torch.uint8).cpu().float()
# count bytes matching
count = torch.sum(original == decompressed[:original.shape[0]])
accuracy = count / original.shape[0]
return accuracy
loss = F.mse_loss(decompressed[:original.shape[0]], original)
return loss
if __name__ == "__main__":
@ -79,7 +79,7 @@ if __name__ == "__main__":
with open("./results/compress/compression_results.csv", "w") as f:
# write header
f.write(
"model_type,model_name,context_length,dataset_type,input_file_name,original_file_size,compressed_file_size,match_percentage,compression_time,decompression_time\n"
"model_type,model_name,context_length,dataset_type,input_file_name,original_file_size,compressed_file_size,mse_loss,compression_time,decompression_time\n"
)
for model, context_length, model_name, files in models:
@ -110,7 +110,7 @@ if __name__ == "__main__":
decompression_time = t()
accuracy = compare_files(in_file, decompressed.flatten().cpu())
mse_loss = compare_files(in_file, decompressed.flatten().cpu())
og_file_len = os.path.getsize(in_file)
if compressed is None:
@ -121,5 +121,5 @@ if __name__ == "__main__":
os.remove("./output/tmp.pt")
f.write(
f"{model_name},{model},{context_length},{dataset_type},{file},{og_file_len},{compressed_size},{accuracy},{compression_time},{decompression_time}\n"
f"{model_name},{model},{context_length},{dataset_type},{file},{og_file_len},{compressed_size},{mse_loss},{compression_time},{decompression_time}\n"
)