From 249b228e7a2205a872d12f7e1723cc74306dad62 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Sun, 12 Jun 2022 02:18:42 +0200
Subject: [PATCH] sanity test

---
 eval/plot_std.py                           | 32 ++++++++++++----------
 eval/run_xtimes_cifar.sh                   |  2 +-
 eval/step_configs/config_cifar_sharing.ini |  2 +-
 3 files changed, 20 insertions(+), 16 deletions(-)

diff --git a/eval/plot_std.py b/eval/plot_std.py
index 5461657..1b92385 100644
--- a/eval/plot_std.py
+++ b/eval/plot_std.py
@@ -22,18 +22,18 @@ def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
     plt.legend(loc=loc)
 
 
-def plot_band(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
+def plot_band(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel, ax):
     cmap = plt.get_cmap("gist_rainbow")
-    plt.title(title)
-    plt.xlabel(xlabel)
+    ax.title(title)
+    ax.xlabel(xlabel)
     y_axis = list(means)
     print("label:", label)
     print("color: ", cmap(1 / nb_plots * pos))
-    plt.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), alpha=0.2)
-    plt.plot(
+    ax.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), alpha=0.2)
+    ax.plot(
         list(x_axis), y_axis, label=label, color=cmap(1 / nb_plots * pos)
     )
-    plt.legend(loc=loc)
+    ax.legend(loc=loc)
 
 
 def plot_results(path, epochs, global_epochs="True"):
@@ -59,8 +59,9 @@ def plot_results(path, epochs, global_epochs="True"):
     losses = {}
     losses_metrics = {"avg":[], "std":[], "name":[]}
     x_label = "global epochs"
-    plt.figure(1)
-    plt.subplot(131)
+    #plt.figure(1)
+    fig, ax = plt.subplots(1, 3, figsize=(18, 6))
+    #plt.subplot(131, figsize=(5.0, 3.0))
     for i, f in enumerate(train_loss):
         filepath = os.path.join(path, f)
         with open(filepath, "r") as inf:
@@ -107,11 +108,12 @@ def plot_results(path, epochs, global_epochs="True"):
             k,
             "upper right",
             x_label,
+            ax[0]
         )
     tlosses = {}
     tlosses_metrics = {"avg": [], "std": [], "name": []}
     x_label = "global epochs"
-    plt.subplot(132)
+    plt.subplot(132, figsize=(5.0, 3.0))
     for i, f in enumerate(test_loss):
         filepath = os.path.join(path, f)
         with open(filepath, "r") as inf:
@@ -152,16 +154,17 @@ def plot_results(path, epochs, global_epochs="True"):
             mean_of_means,
             mean_of_std,
             i,
-            len(losses),
+            len(tlosses),
             "Testing Loss",
             k,
             "upper right",
             x_label,
+            ax[1]
         )
 
     taccs = {}
     tacc_metrics = {"avg": [], "std": [], "name": []}
-    plt.subplot(133)
+    plt.subplot(133, figsize=(5.0, 3.0))
     for i, f in enumerate(test_acc):
 
         filepath = os.path.join(path, f)
@@ -205,11 +208,12 @@ def plot_results(path, epochs, global_epochs="True"):
                 mean_of_means,
                 mean_of_std,
                 i,
-                len(losses),
+                len(taccs),
                 "Testing Accuracy",
                 k,
                 "upper right",
                 x_label,
+                ax[2]
             )
 
     for metric, name in zip([losses_metrics, tlosses_metrics, tacc_metrics], ["losses_metrics", "tlosses_metrics", "accuracy_metrics"]):
@@ -237,8 +241,8 @@ def plot_results(path, epochs, global_epochs="True"):
         pf = pf.sort_values([name.split("_")[0]+"_values"], 0, ascending=False)
         pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv"))
 
-    plt.savefig(os.path.join(path, "together.svg"), dpi=300, format="svg")
-    plt.savefig(os.path.join(path, "together.png"), dpi=300, format="png")
+    fig.savefig(os.path.join(path, "together.svg"), dpi=300, format="svg")
+    fig.savefig(os.path.join(path, "together.png"), dpi=300, format="png")
 
 
 if __name__ == "__main__":
diff --git a/eval/run_xtimes_cifar.sh b/eval/run_xtimes_cifar.sh
index 0a04f36..69d0c59 100755
--- a/eval/run_xtimes_cifar.sh
+++ b/eval/run_xtimes_cifar.sh
@@ -42,7 +42,7 @@ graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
-global_epochs=1000
+global_epochs=400
 eval_file=testing.py
 log_level=INFO
 
diff --git a/eval/step_configs/config_cifar_sharing.ini b/eval/step_configs/config_cifar_sharing.ini
index 4f3fcca..8df88c5 100644
--- a/eval/step_configs/config_cifar_sharing.ini
+++ b/eval/step_configs/config_cifar_sharing.ini
@@ -8,7 +8,7 @@ test_dir = /mnt/nfs/shared/CIFAR
 sizes =
 random_seed = 99
 partition_niid = True
-shards = 1
+shards = 4
 
 [OPTIMIZER_PARAMS]
 optimizer_package = torch.optim
-- 
GitLab