fix: Also write numbers
This commit is contained in:
parent
e81997f129
commit
9250777691
1 changed files with 10 additions and 0 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
import csv
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
@ -34,6 +35,15 @@ def print_losses(train_losses: list[float], validation_losses: list[float], file
|
||||||
print(f"Saving losses to {filename}...")
|
print(f"Saving losses to {filename}...")
|
||||||
plt.savefig(filename)
|
plt.savefig(filename)
|
||||||
|
|
||||||
|
# Also write to CSV file
|
||||||
|
with open(filename.replace(".png", ".csv"), "w") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
writer.writerow(["epoch", "train_loss", "validation_loss"])
|
||||||
|
for i in range(len(train_losses)):
|
||||||
|
writer.writerow([i, train_losses[i], validation_losses[i]])
|
||||||
|
|
||||||
|
print("Done")
|
||||||
|
|
||||||
|
|
||||||
def determine_device():
|
def determine_device():
|
||||||
# NVIDIA GPUs (most HPC clusters)
|
# NVIDIA GPUs (most HPC clusters)
|
||||||
|
|
|
||||||
Reference in a new issue