diff --git a/src/utils/utils.py b/src/utils/utils.py index 4929f20..a9da3d4 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,3 +1,4 @@ +import csv from os import path import matplotlib.pyplot as plt @@ -34,6 +35,15 @@ def print_losses(train_losses: list[float], validation_losses: list[float], file print(f"Saving losses to {filename}...") plt.savefig(filename) + # Also write to CSV file + with open(filename.replace(".png", ".csv"), "w") as f: + writer = csv.writer(f) + writer.writerow(["epoch", "train_loss", "validation_loss"]) + for i in range(len(train_losses)): + writer.writerow([i, train_losses[i], validation_losses[i]]) + + print("Done") + def determine_device(): # NVIDIA GPUs (most HPC clusters)