feat: Improved plots
This commit is contained in:
parent
b62f06018d
commit
d663f270e1
1 changed files with 98 additions and 25 deletions
|
|
@ -1,3 +1,4 @@
|
||||||
|
import glob
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
@ -67,7 +68,32 @@ def compression_ratios(df: pd.DataFrame, unique_labels, palette_dict) -> Figure:
|
||||||
|
|
||||||
plt.yticks(rotation=45, ha="right")
|
plt.yticks(rotation=45, ha="right")
|
||||||
|
|
||||||
ax.grid(True)
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
return plt.gcf()
|
||||||
|
|
||||||
|
|
||||||
|
def compression_ratio_v_compression_time(df: pd.DataFrame, unique_labels, palette_dict, markers_dict) -> Figure:
    """Scatter-plot compression time against compression ratio.

    Every row of ``df`` is drawn at (``RATE_COL``, ``COMPRESS_TIME_COL``).
    The ``LABEL_COL`` column selects both the colour and the marker of a
    point, using ``unique_labels`` as the ordering for both mappings.

    Args:
        df: Measurements to plot.
        unique_labels: Label values, used as ``hue_order`` and ``style_order``.
        palette_dict: Passed to seaborn as the ``palette``.
        markers_dict: Passed to seaborn as the ``markers``.

    Returns:
        The figure created for this plot.
    """
    fig = plt.figure()

    # Hue and style are keyed on the same column so colour and marker agree.
    ax = sns.scatterplot(
        data=df,
        x=RATE_COL,
        y=COMPRESS_TIME_COL,
        hue=LABEL_COL,
        hue_order=unique_labels,
        palette=palette_dict,
        style=LABEL_COL,
        style_order=unique_labels,
        markers=markers_dict,
    )

    ax.legend(title='Compressor')
    ax.set_xlabel('Compression ratio')
    ax.set_ylabel('Compression time (s)')
    ax.grid(True, alpha=0.3)

    return fig
|
||||||
|
|
||||||
|
|
@ -120,25 +146,35 @@ def filesize_v_mse(df: pd.DataFrame) -> Figure:
|
||||||
size = row[INPUT_SIZE_COL]
|
size = row[INPUT_SIZE_COL]
|
||||||
return f"{filename} ({size:.4f} MB)"
|
return f"{filename} ({size:.4f} MB)"
|
||||||
|
|
||||||
df['input_filename_size'] = df.apply(filename_and_size, axis=1)
|
def size(row):
    """Build the size label for a row from its ``input_filename`` value.

    A leading ``text`` and/or ``genome`` and a trailing ``txt`` and/or
    ``fna`` are removed from the filename, and what remains is wrapped in
    asterisks.

    Args:
        row: Mapping-like row (e.g. a DataFrame row) with an
            ``input_filename`` string entry.

    Returns:
        The trimmed filename surrounded by ``*`` characters.
    """
    size_name = row['input_filename']

    # str.lstrip/str.rstrip treat their argument as a *set of characters*,
    # not as a literal prefix/suffix, so they can eat characters that belong
    # to the name itself (e.g. "extra_1.txt" -> "ra_1."). Remove the exact
    # prefixes and suffixes instead.
    for prefix in ('text', 'genome'):
        if size_name.startswith(prefix):
            size_name = size_name[len(prefix):]
    for suffix in ('txt', 'fna'):
        if size_name.endswith(suffix):
            size_name = size_name[:-len(suffix)]

    return f"*{size_name}*"
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
df['input_filename_size'] = df.apply(filename_and_size, axis=1)
|
||||||
sns.barplot(
|
df['input_size'] = df.apply(size, axis=1)
|
||||||
|
|
||||||
|
g = sns.catplot(
|
||||||
data=df,
|
data=df,
|
||||||
y='input_filename',
|
kind="bar",
|
||||||
x=DISTORTION_COL,
|
x=DISTORTION_COL,
|
||||||
|
y='input_size',
|
||||||
|
col='training_dataset',
|
||||||
hue=CONTEXT_COL,
|
hue=CONTEXT_COL,
|
||||||
ax=ax,
|
palette='Set2',
|
||||||
palette='Set2'
|
height=5,
|
||||||
|
aspect=0.6
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.title('MSE for autoencoder')
|
g.set_axis_labels("MSE Loss", "Filename")
|
||||||
plt.xlabel('MSE')
|
g.set_titles("Autoencoder trained on {col_name}")
|
||||||
plt.ylabel('Filename')
|
|
||||||
plt.yticks(rotation=45, ha="right")
|
|
||||||
plt.legend(title='Context size')
|
|
||||||
|
|
||||||
plt.grid(True)
|
# plt.title('MSE for autoencoder')
|
||||||
|
# plt.yticks(rotation=45, ha="right")
|
||||||
|
# plt.legend(title='Context size')
|
||||||
|
g.tight_layout()
|
||||||
|
|
||||||
return plt.gcf()
|
return plt.gcf()
|
||||||
|
|
||||||
|
|
@ -164,7 +200,7 @@ def mse_losses(df: pd.DataFrame, unique_labels, palette_dict) -> Figure:
|
||||||
|
|
||||||
plt.yticks(rotation=45, ha="right")
|
plt.yticks(rotation=45, ha="right")
|
||||||
|
|
||||||
ax.grid(True)
|
ax.grid(True, alpha=0.3)
|
||||||
|
|
||||||
return plt.gcf()
|
return plt.gcf()
|
||||||
|
|
||||||
|
|
@ -238,8 +274,8 @@ def split_graph(
|
||||||
f.text(0.5, 0, x_axis_label, ha='center', va='center')
|
f.text(0.5, 0, x_axis_label, ha='center', va='center')
|
||||||
ax_left.set_ylabel(y_axis_label)
|
ax_left.set_ylabel(y_axis_label)
|
||||||
|
|
||||||
ax_left.grid(True)
|
ax_left.grid(True, alpha=0.3)
|
||||||
ax_right.grid(True)
|
ax_right.grid(True, alpha=0.3)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
return f, ax_left, ax_right
|
return f, ax_left, ax_right
|
||||||
|
|
@ -247,7 +283,7 @@ def split_graph(
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
df: pd.DataFrame, unique_labels, palette_dict, markers_dict,
|
df: pd.DataFrame, unique_labels, palette_dict, markers_dict,
|
||||||
tgt_dir: str, dpi: int = 300
|
tgt_dir: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Generate all the plots"""
|
"""Generate all the plots"""
|
||||||
# Make plots
|
# Make plots
|
||||||
|
|
@ -270,6 +306,11 @@ def generate(
|
||||||
os.path.join(tgt_dir, 'compression_ratios.png'),
|
os.path.join(tgt_dir, 'compression_ratios.png'),
|
||||||
bbox_inches='tight'
|
bbox_inches='tight'
|
||||||
)
|
)
|
||||||
|
compression_ratio_v_compression_time(df, unique_labels, palette_dict, markers_dict).savefig(
|
||||||
|
os.path.join(tgt_dir, 'compression_ratio_v_compression_time.png'),
|
||||||
|
bbox_inches='tight'
|
||||||
|
)
|
||||||
|
|
||||||
filesize_v_mse(df).savefig(
|
filesize_v_mse(df).savefig(
|
||||||
os.path.join(tgt_dir, 'filesize_mse.png'),
|
os.path.join(tgt_dir, 'filesize_mse.png'),
|
||||||
bbox_inches='tight'
|
bbox_inches='tight'
|
||||||
|
|
@ -280,7 +321,7 @@ def generate(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def setup(tgt_dir):
|
def setup(tgt_dir, dpi = 300):
|
||||||
# Create the targ directory if it does not exist
|
# Create the targ directory if it does not exist
|
||||||
os.makedirs(tgt_dir, exist_ok=True)
|
os.makedirs(tgt_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
@ -288,7 +329,7 @@ def setup(tgt_dir):
|
||||||
params = {'text.usetex': True,
|
params = {'text.usetex': True,
|
||||||
'font.size': 11,
|
'font.size': 11,
|
||||||
'font.family': 'serif',
|
'font.family': 'serif',
|
||||||
'figure.dpi': 300,
|
'figure.dpi': dpi,
|
||||||
}
|
}
|
||||||
plt.rcParams.update(params)
|
plt.rcParams.update(params)
|
||||||
|
|
||||||
|
|
@ -327,8 +368,8 @@ def main():
|
||||||
df = pd.read_csv("measurements.csv")
|
df = pd.read_csv("measurements.csv")
|
||||||
|
|
||||||
tgt_dir = "figures"
|
tgt_dir = "figures"
|
||||||
setup(tgt_dir)
|
setup(tgt_dir, 300)
|
||||||
generate(*preprocessing(df), tgt_dir=tgt_dir, dpi=150)
|
generate(*preprocessing(df), tgt_dir=tgt_dir)
|
||||||
|
|
||||||
|
|
||||||
def old_results():
|
def old_results():
|
||||||
|
|
@ -395,7 +436,7 @@ def old_results():
|
||||||
# plt.title(f"{model_type.capitalize()} compressed file evolution: {dataset_type}")
|
# plt.title(f"{model_type.capitalize()} compressed file evolution: {dataset_type}")
|
||||||
plt.xlabel("Original file size (MB)")
|
plt.xlabel("Original file size (MB)")
|
||||||
plt.ylabel("Compressed file size (MB)")
|
plt.ylabel("Compressed file size (MB)")
|
||||||
plt.ylim(0, model_df["compressed_file_size"].max() / 1e6)
|
plt.ylim(0, df[df["model_type"] == model_type]["compressed_file_size"].max() / 1e6)
|
||||||
plt.legend(title="Context size")
|
plt.legend(title="Context size")
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(f"./graphs/{model_type}_{dataset_type}_compression_ratio.png")
|
plt.savefig(f"./graphs/{model_type}_{dataset_type}_compression_ratio.png")
|
||||||
|
|
@ -428,7 +469,7 @@ def old_results():
|
||||||
linestyle=linestyle
|
linestyle=linestyle
|
||||||
)
|
)
|
||||||
|
|
||||||
plt.grid(True)
|
plt.grid(True, alpha=0.3)
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.title(f"(Log-linear) Extrapolated execution time for CNN")
|
plt.title(f"(Log-linear) Extrapolated execution time for CNN")
|
||||||
# plt.xscale('log')
|
# plt.xscale('log')
|
||||||
|
|
@ -466,11 +507,43 @@ def old_results():
|
||||||
plt.xlabel("MSE loss")
|
plt.xlabel("MSE loss")
|
||||||
plt.ylabel("Filename")
|
plt.ylabel("Filename")
|
||||||
plt.legend()
|
plt.legend()
|
||||||
plt.grid(True)
|
plt.grid(True, alpha=0.3)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plt.savefig(f"./graphs/{model_type}_loss.png")
|
plt.savefig(f"./graphs/{model_type}_loss.png")
|
||||||
|
|
||||||
|
|
||||||
|
def training_loss(df, loss) -> Figure:
    """Plot the training and validation loss curves held in ``df``.

    Args:
        df: Table with ``train_loss`` and ``validation_loss`` columns.
        loss: Text used as the y-axis label. The value ``'MSE Loss'``
            selects an upper y-limit of 0.01; any other value selects 6.

    Returns:
        The 4x3 inch figure containing both curves.
    """
    fig = plt.figure(figsize=(4, 3))
    ax = fig.gca()

    curves = (
        ('train_loss', "Training loss"),
        ('validation_loss', "Validation losses"),
    )
    for column, legend_label in curves:
        ax.plot(df[column], label=legend_label)

    ax.set_xlabel("Epoch")
    ax.set_ylabel(loss)

    # The y-range is fixed per loss type rather than derived from the data.
    upper = 0.01 if loss == 'MSE Loss' else 6
    ax.set_ylim(0, upper)

    ax.legend()
    fig.tight_layout()

    return fig
|
||||||
|
|
||||||
|
|
||||||
|
def make_training_graphs(models_dir, loss):
    """Write a loss-curve PNG next to every CSV file found in ``models_dir``.

    Each ``*.csv`` file directly inside ``models_dir`` is read with pandas,
    plotted through ``training_loss`` and saved under the same path with a
    ``.png`` extension.

    Args:
        models_dir: Directory that is searched for ``*.csv`` files.
        loss: Y-axis label forwarded to ``training_loss``.
    """
    for csv_path in glob.glob(os.path.join(models_dir, '*.csv')):
        history = pd.read_csv(csv_path)
        fig = training_loss(history, loss)

        # Swap only the file extension; str.replace would also rewrite any
        # other ".csv" occurring earlier in the path.
        png_path = os.path.splitext(csv_path)[0] + '.png'
        fig.savefig(png_path, bbox_inches='tight')

        # Release the figure so they do not pile up in pyplot's registry
        # while looping over many files.
        plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
old_results()
|
# old_results()
|
||||||
|
make_training_graphs('../models/autoencoder', 'MSE Loss')
|
||||||
|
make_training_graphs('../models/cnn', 'Cross Entropy Loss')
|
||||||
|
|
|
||||||
Reference in a new issue