From c1c07aff74987d0b1f16f50d8da7531ae3570278 Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Fri, 17 Jun 2022 02:48:38 +0200 Subject: [PATCH] updated plotting function. --- eval/plot_std.py | 94 ++++++++++++++++++++++++++---------------------- 1 file changed, 52 insertions(+), 42 deletions(-) diff --git a/eval/plot_std.py b/eval/plot_std.py index 1b92385..c5255f9 100644 --- a/eval/plot_std.py +++ b/eval/plot_std.py @@ -5,8 +5,14 @@ import sys import numpy as np import pandas as pd +import matplotlib from matplotlib import pyplot as plt +font = {'family' : 'normal', + 'size' : 16} + +matplotlib.rc('font', **font) + def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel): cmap = plt.get_cmap("gist_rainbow") @@ -24,16 +30,18 @@ def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel): def plot_band(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel, ax): cmap = plt.get_cmap("gist_rainbow") - ax.title(title) - ax.xlabel(xlabel) + print(type(ax)) + ax.set_title(title) + ax.set_xlabel(xlabel) y_axis = list(means) print("label:", label) print("color: ", cmap(1 / nb_plots * pos)) - ax.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), alpha=0.2) ax.plot( list(x_axis), y_axis, label=label, color=cmap(1 / nb_plots * pos) ) - ax.legend(loc=loc) + ax.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), color=cmap(1 / nb_plots * pos), alpha=0.2) + if loc is not None: + ax.legend(loc=loc) def plot_results(path, epochs, global_epochs="True"): @@ -61,6 +69,7 @@ def plot_results(path, epochs, global_epochs="True"): x_label = "global epochs" #plt.figure(1) fig, ax = plt.subplots(1, 3, figsize=(18, 6)) + plt.tight_layout(pad=2, w_pad=2, h_pad=2) #plt.subplot(131, figsize=(5.0, 3.0)) for i, f in enumerate(train_loss): filepath = os.path.join(path, f) @@ -68,7 +77,7 @@ def plot_results(path, epochs, global_epochs="True"): results_csv = pd.read_csv(inf) # Plot Training loss #plt.figure(1) - norm_name = f[len("train_loss") + 1 : -len(":2022-03-24T17:54.csv")] + norm_name = f[len("train_loss") + 1 : -4].split(":")[0] # -len(":2022-03-24T17:54.csv") if global_epochs: rounds = results_csv["rounds"].iloc[0] @@ -94,7 +103,7 @@ def plot_results(path, epochs, global_epochs="True"): mean_of_means.append(vals[0]) mean_of_std.append(vals[1]) mean_of_means = np.average(mean_of_means, axis = 0) - mean_of_std = np.average(mean_of_std, axis=0) + mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std)) losses_metrics["avg"].append(mean_of_means) losses_metrics["std"].append(mean_of_std) losses_metrics["name"].append(k) @@ -106,19 +115,19 @@ def plot_results(path, epochs, global_epochs="True"): len(losses), "Training Loss", k, - "upper right", + None, #"upper right", x_label, ax[0] ) tlosses = {} tlosses_metrics = {"avg": [], "std": [], "name": []} x_label = "global epochs" - plt.subplot(132, figsize=(5.0, 3.0)) + #plt.subplot(132, figsize=(5.0, 3.0)) for i, f in enumerate(test_loss): filepath = os.path.join(path, f) with open(filepath, "r") as inf: results_csv = pd.read_csv(inf) - norm_name = f[len("test_loss") + 1 : -len(":2022-03-24T17:54.csv")] + norm_name = f[len("test_loss") + 1 : -4].split(":")[0] # -len(":2022-03-24T17:54.csv") if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) @@ -145,7 +154,8 @@ def plot_results(path, epochs, global_epochs="True"): mean_of_means.append(vals[0]) mean_of_std.append(vals[1]) mean_of_means = np.average(mean_of_means, axis=0) - mean_of_std = np.average(mean_of_std, axis=0) + mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std)) + # np.average(mean_of_std, axis=0) tlosses_metrics["avg"].append(mean_of_means) tlosses_metrics["std"].append(mean_of_std) tlosses_metrics["name"].append(k) @@ -157,21 +167,21 @@ def plot_results(path, epochs, global_epochs="True"): len(tlosses), "Testing Loss", k, - "upper right", + None, #"upper right", x_label, ax[1] ) taccs = {} tacc_metrics = {"avg": [], "std": [], "name": []} - plt.subplot(133, figsize=(5.0, 3.0)) + #plt.subplot(133, figsize=(5.0, 3.0)) for i, f in enumerate(test_acc): filepath = os.path.join(path, f) with open(filepath, "r") as inf: results_csv = pd.read_csv(inf) - norm_name = f[len("test_loss") + 1: -len(":2022-03-24T17:54.csv")] + norm_name = f[len("test_acc") + 1: -4].split(":")[0] # -len(":2022-03-24T17:54.csv") if global_epochs: rounds = results_csv["rounds"].iloc[0] print("Rounds: ", rounds) @@ -188,33 +198,33 @@ def plot_results(path, epochs, global_epochs="True"): x_label = "communication rounds" taccs.setdefault(norm_name, []).append((means, stdevs, x_axis)) - for i, tmp in enumerate(taccs.items()): - (k, v) = tmp - mean_of_means = [] - mean_of_std = [] - x_axis = v[0][2] - for vals in v: - mean_of_means.append(vals[0]) - mean_of_std.append(vals[1]) - print(mean_of_means) - mean_of_means = np.average(mean_of_means, axis=0) - print(mean_of_means) - mean_of_std = np.average(mean_of_std, axis=0) - tacc_metrics["avg"].append(mean_of_means) - tacc_metrics["std"].append(mean_of_std) - tacc_metrics["name"].append(k) - plot_band( - x_axis, - mean_of_means, - mean_of_std, - i, - len(taccs), - "Testing Accuracy", - k, - "upper right", - x_label, - ax[2] - ) + for i, tmp in enumerate(taccs.items()): + (k, v) = tmp + mean_of_means = [] + mean_of_std = [] + x_axis = v[0][2] + for vals in v: + mean_of_means.append(vals[0]) + mean_of_std.append(vals[1]) + print(mean_of_means) + mean_of_means = np.average(mean_of_means, axis=0) + print(mean_of_means) + mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std)) + tacc_metrics["avg"].append(mean_of_means) + tacc_metrics["std"].append(mean_of_std) + tacc_metrics["name"].append(k) + plot_band( + x_axis, + mean_of_means, + mean_of_std, + i, + len(taccs), + "Testing Accuracy", + k, + "lower right", + x_label, + ax[2] + ) for metric, name in zip([losses_metrics, tlosses_metrics, tacc_metrics], ["losses_metrics", "tlosses_metrics", "accuracy_metrics"]): values = [] @@ -239,7 +249,7 @@ def plot_results(path, epochs, global_epochs="True"): print(pf) pf = pf.sort_values([name.split("_")[0]+"_values"], 0, ascending=False) - pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv")) + pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv"), float_format="%.3f") fig.savefig(os.path.join(path, "together.svg"), dpi=300, format="svg") fig.savefig(os.path.join(path, "together.png"), dpi=300, format="png") @@ -252,4 +262,4 @@ if __name__ == "__main__": # 2: the number of epochs / comm rounds to plot for, # 3: True/False with True meaning plot global epochs and False plot communication rounds print(sys.argv[1], sys.argv[2], sys.argv[3]) - plot_results(sys.argv[1], sys.argv[2], sys.argv[3]) + plot_results(sys.argv[1], sys.argv[2], sys.argv[3]) \ No newline at end of file -- GitLab