fix: accuracy replaced by MSE loss, updated graphs
parent 5bf45e47a5
commit 9cd37f156a
15 changed files with 38 additions and 38 deletions
measure.py | 16
@@ -2,6 +2,7 @@ import os
 from contextlib import contextmanager

 import torch
+import torch.nn.functional as F

 import src.process as p

@@ -26,17 +27,16 @@ def timer():
 def compare_files(original, decompressed: str | torch.Tensor):
     with open(original, "rb") as file:
         original = file.read()
-    original = torch.tensor(list(original), dtype=torch.uint8).cpu()
+    original = torch.tensor(list(original), dtype=torch.uint8).cpu().float()

     if type(decompressed) == "str":
         with open(decompressed, "rb") as file:
             decompressed = file.read()
-        decompressed = torch.tensor(list(decompressed), dtype=torch.uint8).cpu()
+        decompressed = torch.tensor(list(decompressed), dtype=torch.uint8).cpu().float()

-    # count bytes matching
-    count = torch.sum(original == decompressed[:original.shape[0]])
-    accuracy = count / original.shape[0]
-    return accuracy
+    loss = F.mse_loss(decompressed[:original.shape[0]], original)
+    return loss


 if __name__ == "__main__":
@@ -79,7 +79,7 @@ if __name__ == "__main__":
     with open("./results/compress/compression_results.csv", "w") as f:
         # write header
         f.write(
-            "model_type,model_name,context_length,dataset_type,input_file_name,original_file_size,compressed_file_size,match_percentage,compression_time,decompression_time\n"
+            "model_type,model_name,context_length,dataset_type,input_file_name,original_file_size,compressed_file_size,mse_loss,compression_time,decompression_time\n"
         )

         for model, context_length, model_name, files in models:
@@ -110,7 +110,7 @@ if __name__ == "__main__":
             decompression_time = t()


-            accuracy = compare_files(in_file, decompressed.flatten().cpu())
+            mse_loss = compare_files(in_file, decompressed.flatten().cpu())

             og_file_len = os.path.getsize(in_file)
             if compressed is None:
@@ -121,5 +121,5 @@ if __name__ == "__main__":
             os.remove("./output/tmp.pt")

             f.write(
-                f"{model_name},{model},{context_length},{dataset_type},{file},{og_file_len},{compressed_size},{accuracy},{compression_time},{decompression_time}\n"
+                f"{model_name},{model},{context_length},{dataset_type},{file},{og_file_len},{compressed_size},{mse_loss},{compression_time},{decompression_time}\n"
             )
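For readers skimming the diff, here is a minimal standalone sketch of the metric that replaces byte-match accuracy. The helper name mse_between_files and the path-only signature are illustrative, not part of the repo; it mirrors the updated compare_files() logic: decode both files into uint8 tensors, cast them to float (F.mse_loss requires floating-point inputs, hence the added .float() calls), truncate the decompressed stream to the original length, and return the mean squared error.

# Minimal sketch, assuming both inputs are paths to raw byte files.
# mse_between_files is an illustrative name, not a function in this repo.
import torch
import torch.nn.functional as F

def mse_between_files(original_path: str, decompressed_path: str) -> torch.Tensor:
    with open(original_path, "rb") as f:
        original = torch.tensor(list(f.read()), dtype=torch.uint8).float()
    with open(decompressed_path, "rb") as f:
        decompressed = torch.tensor(list(f.read()), dtype=torch.uint8).float()
    # Truncate to the original length (as compare_files does above), then
    # compare byte values as floats.
    return F.mse_loss(decompressed[:original.shape[0]], original)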
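The commit message also mentions updated graphs, but the plotting script is not part of this diff. The following is only a hypothetical sketch of how the new mse_loss column could be charted from compression_results.csv; pandas/matplotlib usage and the output filename are assumptions, not the repo's actual graphing code.

# Hypothetical plotting sketch; not the repo's actual graph script.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("./results/compress/compression_results.csv")
# Average MSE loss per model, using the columns written by measure.py above.
df.groupby("model_name")["mse_loss"].mean().plot(kind="bar")
plt.ylabel("MSE loss")
plt.tight_layout()
plt.savefig("./results/compress/mse_loss_per_model.png")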