From c1c07aff74987d0b1f16f50d8da7531ae3570278 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Fri, 17 Jun 2022 02:48:38 +0200
Subject: [PATCH] updated plotting function.

---
 eval/plot_std.py | 94 ++++++++++++++++++++++++++----------------------
 1 file changed, 52 insertions(+), 42 deletions(-)

diff --git a/eval/plot_std.py b/eval/plot_std.py
index 1b92385..c5255f9 100644
--- a/eval/plot_std.py
+++ b/eval/plot_std.py
@@ -5,8 +5,14 @@ import sys
 
 import numpy as np
 import pandas as pd
+import matplotlib
 from matplotlib import pyplot as plt
 
+font = {'family' : 'normal',
+        'size'   : 16}
+
+matplotlib.rc('font', **font)
+
 
 def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
     cmap = plt.get_cmap("gist_rainbow")
@@ -24,16 +30,18 @@ def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
 
 def plot_band(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel, ax):
     cmap = plt.get_cmap("gist_rainbow")
-    ax.title(title)
-    ax.xlabel(xlabel)
+    print(type(ax))
+    ax.set_title(title)
+    ax.set_xlabel(xlabel)
     y_axis = list(means)
     print("label:", label)
     print("color: ", cmap(1 / nb_plots * pos))
-    ax.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), alpha=0.2)
     ax.plot(
         list(x_axis), y_axis, label=label, color=cmap(1 / nb_plots * pos)
     )
-    ax.legend(loc=loc)
+    ax.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), color=cmap(1 / nb_plots * pos), alpha=0.2)
+    if loc is not None:
+        ax.legend(loc=loc)
 
 
 def plot_results(path, epochs, global_epochs="True"):
@@ -61,6 +69,7 @@ def plot_results(path, epochs, global_epochs="True"):
     x_label = "global epochs"
     #plt.figure(1)
     fig, ax = plt.subplots(1, 3, figsize=(18, 6))
+    plt.tight_layout(pad=2, w_pad=2, h_pad=2)
     #plt.subplot(131, figsize=(5.0, 3.0))
     for i, f in enumerate(train_loss):
         filepath = os.path.join(path, f)
@@ -68,7 +77,7 @@ def plot_results(path, epochs, global_epochs="True"):
             results_csv = pd.read_csv(inf)
         # Plot Training loss
         #plt.figure(1)
-        norm_name = f[len("train_loss") + 1 : -len(":2022-03-24T17:54.csv")]
+        norm_name = f[len("train_loss") + 1 : -4].split(":")[0] # -len(":2022-03-24T17:54.csv")
 
         if global_epochs:
             rounds = results_csv["rounds"].iloc[0]
@@ -94,7 +103,7 @@ def plot_results(path, epochs, global_epochs="True"):
             mean_of_means.append(vals[0])
             mean_of_std.append(vals[1])
         mean_of_means = np.average(mean_of_means, axis = 0)
-        mean_of_std = np.average(mean_of_std, axis=0)
+        mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std))
         losses_metrics["avg"].append(mean_of_means)
         losses_metrics["std"].append(mean_of_std)
         losses_metrics["name"].append(k)
@@ -106,19 +115,19 @@ def plot_results(path, epochs, global_epochs="True"):
             len(losses),
             "Training Loss",
             k,
-            "upper right",
+            None, #"upper right",
             x_label,
             ax[0]
         )
     tlosses = {}
     tlosses_metrics = {"avg": [], "std": [], "name": []}
     x_label = "global epochs"
-    plt.subplot(132, figsize=(5.0, 3.0))
+    #plt.subplot(132, figsize=(5.0, 3.0))
     for i, f in enumerate(test_loss):
         filepath = os.path.join(path, f)
         with open(filepath, "r") as inf:
             results_csv = pd.read_csv(inf)
-        norm_name = f[len("test_loss") + 1 : -len(":2022-03-24T17:54.csv")]
+        norm_name = f[len("test_loss") + 1 : -4].split(":")[0] # -len(":2022-03-24T17:54.csv")
         if global_epochs:
             rounds = results_csv["rounds"].iloc[0]
             print("Rounds: ", rounds)
@@ -145,7 +154,8 @@ def plot_results(path, epochs, global_epochs="True"):
             mean_of_means.append(vals[0])
             mean_of_std.append(vals[1])
         mean_of_means = np.average(mean_of_means, axis=0)
-        mean_of_std = np.average(mean_of_std, axis=0)
+        mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std))
+ # np.average(mean_of_std, axis=0)
         tlosses_metrics["avg"].append(mean_of_means)
         tlosses_metrics["std"].append(mean_of_std)
         tlosses_metrics["name"].append(k)
@@ -157,21 +167,21 @@ def plot_results(path, epochs, global_epochs="True"):
             len(tlosses),
             "Testing Loss",
             k,
-            "upper right",
+            None, #"upper right",
             x_label,
             ax[1]
         )
 
     taccs = {}
     tacc_metrics = {"avg": [], "std": [], "name": []}
-    plt.subplot(133, figsize=(5.0, 3.0))
+    #plt.subplot(133, figsize=(5.0, 3.0))
     for i, f in enumerate(test_acc):
 
         filepath = os.path.join(path, f)
         with open(filepath, "r") as inf:
             results_csv = pd.read_csv(inf)
 
-        norm_name = f[len("test_loss") + 1: -len(":2022-03-24T17:54.csv")]
+        norm_name = f[len("test_acc") + 1: -4].split(":")[0] # -len(":2022-03-24T17:54.csv")
         if global_epochs:
             rounds = results_csv["rounds"].iloc[0]
             print("Rounds: ", rounds)
@@ -188,33 +198,33 @@ def plot_results(path, epochs, global_epochs="True"):
             x_label = "communication rounds"
         taccs.setdefault(norm_name, []).append((means, stdevs, x_axis))
 
-        for i, tmp in enumerate(taccs.items()):
-            (k, v) = tmp
-            mean_of_means = []
-            mean_of_std = []
-            x_axis = v[0][2]
-            for vals in v:
-                mean_of_means.append(vals[0])
-                mean_of_std.append(vals[1])
-            print(mean_of_means)
-            mean_of_means = np.average(mean_of_means, axis=0)
-            print(mean_of_means)
-            mean_of_std = np.average(mean_of_std, axis=0)
-            tacc_metrics["avg"].append(mean_of_means)
-            tacc_metrics["std"].append(mean_of_std)
-            tacc_metrics["name"].append(k)
-            plot_band(
-                x_axis,
-                mean_of_means,
-                mean_of_std,
-                i,
-                len(taccs),
-                "Testing Accuracy",
-                k,
-                "upper right",
-                x_label,
-                ax[2]
-            )
+    for i, tmp in enumerate(taccs.items()):
+        (k, v) = tmp
+        mean_of_means = []
+        mean_of_std = []
+        x_axis = v[0][2]
+        for vals in v:
+            mean_of_means.append(vals[0])
+            mean_of_std.append(vals[1])
+        print(mean_of_means)
+        mean_of_means = np.average(mean_of_means, axis=0)
+        print(mean_of_means)
+        mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std))
+        tacc_metrics["avg"].append(mean_of_means)
+        tacc_metrics["std"].append(mean_of_std)
+        tacc_metrics["name"].append(k)
+        plot_band(
+            x_axis,
+            mean_of_means,
+            mean_of_std,
+            i,
+            len(taccs),
+            "Testing Accuracy",
+            k,
+            "lower right",
+            x_label,
+            ax[2]
+        )
 
     for metric, name in zip([losses_metrics, tlosses_metrics, tacc_metrics], ["losses_metrics", "tlosses_metrics", "accuracy_metrics"]):
         values = []
@@ -239,7 +249,7 @@ def plot_results(path, epochs, global_epochs="True"):
         print(pf)
 
         pf = pf.sort_values([name.split("_")[0]+"_values"], 0, ascending=False)
-        pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv"))
+        pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv"), float_format="%.3f")
 
     fig.savefig(os.path.join(path, "together.svg"), dpi=300, format="svg")
     fig.savefig(os.path.join(path, "together.png"), dpi=300, format="png")
@@ -252,4 +262,4 @@ if __name__ == "__main__":
     # 2: the number of epochs / comm rounds to plot for,
     # 3: True/False with True meaning plot global epochs and False plot communication rounds
     print(sys.argv[1], sys.argv[2], sys.argv[3])
-    plot_results(sys.argv[1], sys.argv[2], sys.argv[3])
+    plot_results(sys.argv[1], sys.argv[2], sys.argv[3])
\ No newline at end of file
-- 
GitLab