diff --git a/eval/plot_std.py b/eval/plot_std.py
new file mode 100644
index 0000000000000000000000000000000000000000..5461657854bd7690cceab2e73610a5c67a3782e5
--- /dev/null
+++ b/eval/plot_std.py
@@ -0,0 +1,251 @@
+import distutils
+import json
+import os
+import sys
+
+import numpy as np
+import pandas as pd
+from matplotlib import pyplot as plt
+
+
+def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
+    cmap = plt.get_cmap("gist_rainbow")
+    plt.title(title)
+    plt.xlabel(xlabel)
+    y_axis = list(means)
+    err = list(stdevs)
+    print("label:", label)
+    print("color: ", cmap(1 / nb_plots * pos))
+    plt.errorbar(
+        list(x_axis), y_axis, yerr=err, label=label, color=cmap(1 / nb_plots * pos)
+    )
+    plt.legend(loc=loc)
+
+
+def plot_band(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
+    cmap = plt.get_cmap("gist_rainbow")
+    plt.title(title)
+    plt.xlabel(xlabel)
+    y_axis = list(means)
+    print("label:", label)
+    print("color: ", cmap(1 / nb_plots * pos))
+    plt.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), alpha=0.2)
+    plt.plot(
+        list(x_axis), y_axis, label=label, color=cmap(1 / nb_plots * pos)
+    )
+    plt.legend(loc=loc)
+
+
+def plot_results(path, epochs, global_epochs="True"):
+    if global_epochs.lower() in ['true', '1', 't', 'y', 'yes']:
+        global_epochs = True
+    else:
+        global_epochs = False
+    epochs = int(epochs)
+    # rounds = int(rounds)
+    folders = os.listdir(path)
+    folders.sort()
+    print("Reading folders from: ", path)
+    print("Folders: ", folders)
+    bytes_means, bytes_stdevs = {}, {}
+    meta_means, meta_stdevs = {}, {}
+    data_means, data_stdevs = {}, {}
+
+    files = os.listdir(path)
+    files = [f for f in files if f.endswith(".csv")]
+    train_loss = sorted([f for f in files if f.startswith("train_loss")])
+    test_acc = sorted([f for f in files if f.startswith("test_acc")])
+    test_loss = sorted([f for f in files if f.startswith("test_loss")])
+    losses = {}
+    losses_metrics = {"avg":[], "std":[], "name":[]}
+    x_label = "global epochs"
+    plt.figure(1)
+    plt.subplot(131)
+    for i, f in enumerate(train_loss):
+        filepath = os.path.join(path, f)
+        with open(filepath, "r") as inf:
+            results_csv = pd.read_csv(inf)
+        # Plot Training loss
+        #plt.figure(1)
+        norm_name = f[len("train_loss") + 1 : -len(":2022-03-24T17:54.csv")]
+
+        if global_epochs:
+            rounds = results_csv["rounds"].iloc[0]
+            print("Rounds: ", rounds)
+            results_cr = results_csv[results_csv.rounds <= epochs*rounds]
+            means = results_cr["mean"].to_numpy()
+            stdevs = results_cr["std"].to_numpy()
+            x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1))
+            x_label = "global epochs"
+        else:
+            results_cr = results_csv[results_csv.rounds <= epochs]
+            means = results_cr["mean"].to_numpy()
+            stdevs = results_cr["std"].to_numpy()
+            x_axis = results_cr["rounds"].to_numpy()
+            x_label = "communication rounds"
+        losses.setdefault(norm_name, []).append((means, stdevs, x_axis))
+    for i , tmp in enumerate(losses.items()):
+        (k, v) = tmp
+        mean_of_means = []
+        mean_of_std = []
+        x_axis = v[0][2]
+        for vals in v:
+            mean_of_means.append(vals[0])
+            mean_of_std.append(vals[1])
+        mean_of_means = np.average(mean_of_means, axis = 0)
+        mean_of_std = np.average(mean_of_std, axis=0)
+        losses_metrics["avg"].append(mean_of_means)
+        losses_metrics["std"].append(mean_of_std)
+        losses_metrics["name"].append(k)
+        plot_band(
+            x_axis,
+            mean_of_means,
+            mean_of_std,
+            i,
+            len(losses),
+            "Training Loss",
+            k,
+            "upper right",
+            x_label,
+        )
+    tlosses = {}
+    tlosses_metrics = {"avg": [], "std": [], "name": []}
+    x_label = "global epochs"
+    plt.subplot(132)
+    for i, f in enumerate(test_loss):
+        filepath = os.path.join(path, f)
+        with open(filepath, "r") as inf:
+            results_csv = pd.read_csv(inf)
+        norm_name = f[len("test_loss") + 1 : -len(":2022-03-24T17:54.csv")]
+        if global_epochs:
+            rounds = results_csv["rounds"].iloc[0]
+            print("Rounds: ", rounds)
+            results_cr = results_csv[results_csv.rounds <= epochs*rounds]
+            means = results_cr["mean"].to_numpy()
+            stdevs = results_cr["std"].to_numpy()
+            x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1))
+            x_label = "global epochs"
+        else:
+            results_cr = results_csv[results_csv.rounds <= epochs]
+            means = results_cr["mean"].to_numpy()
+            stdevs = results_cr["std"].to_numpy()
+            x_axis = results_cr["rounds"].to_numpy()
+            x_label = "communication rounds"
+        print("x axis:", x_axis)
+        tlosses.setdefault(norm_name, []).append((means, stdevs, x_axis))
+
+    for i, tmp in enumerate(tlosses.items()):
+        (k, v) = tmp
+        mean_of_means = []
+        mean_of_std = []
+        x_axis = v[0][2]
+        for vals in v:
+            mean_of_means.append(vals[0])
+            mean_of_std.append(vals[1])
+        mean_of_means = np.average(mean_of_means, axis=0)
+        mean_of_std = np.average(mean_of_std, axis=0)
+        tlosses_metrics["avg"].append(mean_of_means)
+        tlosses_metrics["std"].append(mean_of_std)
+        tlosses_metrics["name"].append(k)
+        plot_band(
+            x_axis,
+            mean_of_means,
+            mean_of_std,
+            i,
+            len(losses),
+            "Testing Loss",
+            k,
+            "upper right",
+            x_label,
+        )
+
+    taccs = {}
+    tacc_metrics = {"avg": [], "std": [], "name": []}
+    plt.subplot(133)
+    for i, f in enumerate(test_acc):
+
+        filepath = os.path.join(path, f)
+        with open(filepath, "r") as inf:
+            results_csv = pd.read_csv(inf)
+
+        norm_name = f[len("test_loss") + 1: -len(":2022-03-24T17:54.csv")]
+        if global_epochs:
+            rounds = results_csv["rounds"].iloc[0]
+            print("Rounds: ", rounds)
+            results_cr = results_csv[results_csv.rounds <= epochs*rounds]
+            means = results_cr["mean"].to_numpy()
+            stdevs = results_cr["std"].to_numpy()
+            x_axis = results_cr["rounds"].to_numpy() / rounds # list(np.arange(0, len(means), 1))
+            x_label = "global epochs"
+        else:
+            results_cr = results_csv[results_csv.rounds <= epochs]
+            means = results_cr["mean"].to_numpy()
+            stdevs = results_cr["std"].to_numpy()
+            x_axis = results_cr["rounds"].to_numpy()
+            x_label = "communication rounds"
+        taccs.setdefault(norm_name, []).append((means, stdevs, x_axis))
+
+        for i, tmp in enumerate(taccs.items()):
+            (k, v) = tmp
+            mean_of_means = []
+            mean_of_std = []
+            x_axis = v[0][2]
+            for vals in v:
+                mean_of_means.append(vals[0])
+                mean_of_std.append(vals[1])
+            print(mean_of_means)
+            mean_of_means = np.average(mean_of_means, axis=0)
+            print(mean_of_means)
+            mean_of_std = np.average(mean_of_std, axis=0)
+            tacc_metrics["avg"].append(mean_of_means)
+            tacc_metrics["std"].append(mean_of_std)
+            tacc_metrics["name"].append(k)
+            plot_band(
+                x_axis,
+                mean_of_means,
+                mean_of_std,
+                i,
+                len(losses),
+                "Testing Accuracy",
+                k,
+                "upper right",
+                x_label,
+            )
+
+    for metric, name in zip([losses_metrics, tlosses_metrics, tacc_metrics], ["losses_metrics", "tlosses_metrics", "accuracy_metrics"]):
+        values = []
+        stds = []
+        names = []
+        for i, val in enumerate(metric["avg"]):
+            if "accuracy" in name:
+                idx = np.argmax(val)
+            else:
+                idx = np.argmin(val)
+            values.append(val[idx])
+            stds.append(metric["std"][i][idx])
+            names.append(metric["name"][i])
+            print(idx, len(val), val)
+        pf = pd.DataFrame(
+            {
+                name.split("_")[0]+"_values": values,
+                name.split("_")[0]+"_std": stds,
+            },
+            names,
+        )
+        print(pf)
+
+        pf = pf.sort_values([name.split("_")[0]+"_values"], 0, ascending=False)
+        pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv"))
+
+    plt.savefig(os.path.join(path, "together.svg"), dpi=300, format="svg")
+    plt.savefig(os.path.join(path, "together.png"), dpi=300, format="png")
+
+
+if __name__ == "__main__":
+    assert len(sys.argv) == 4
+    # The args are:
+    # 1: the folder with the csv files,
+    # 2: the number of epochs / comm rounds to plot for,
+    # 3: True/False with True meaning plot global epochs and False plot communication rounds
+    print(sys.argv[1], sys.argv[2], sys.argv[3])
+    plot_results(sys.argv[1], sys.argv[2], sys.argv[3])