Skip to content
Snippets Groups Projects
Commit c1c07aff authored by Jeffrey Wigger's avatar Jeffrey Wigger
Browse files

updated plotting function.

parent cdc057cf
No related branches found
No related tags found
No related merge requests found
......@@ -5,8 +5,14 @@ import sys
import numpy as np
import pandas as pd
import matplotlib
from matplotlib import pyplot as plt
font = {'family' : 'normal',
'size' : 16}
matplotlib.rc('font', **font)
def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
cmap = plt.get_cmap("gist_rainbow")
......@@ -24,16 +30,18 @@ def plot(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel):
def plot_band(x_axis, means, stdevs, pos, nb_plots, title, label, loc, xlabel, ax):
cmap = plt.get_cmap("gist_rainbow")
ax.title(title)
ax.xlabel(xlabel)
print(type(ax))
ax.set_title(title)
ax.set_xlabel(xlabel)
y_axis = list(means)
print("label:", label)
print("color: ", cmap(1 / nb_plots * pos))
ax.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), alpha=0.2)
ax.plot(
list(x_axis), y_axis, label=label, color=cmap(1 / nb_plots * pos)
)
ax.legend(loc=loc)
ax.fill_between(list(x_axis), list(means - stdevs), list(means + stdevs), color=cmap(1 / nb_plots * pos), alpha=0.2)
if loc is not None:
ax.legend(loc=loc)
def plot_results(path, epochs, global_epochs="True"):
......@@ -61,6 +69,7 @@ def plot_results(path, epochs, global_epochs="True"):
x_label = "global epochs"
#plt.figure(1)
fig, ax = plt.subplots(1, 3, figsize=(18, 6))
plt.tight_layout(pad=2, w_pad=2, h_pad=2)
#plt.subplot(131, figsize=(5.0, 3.0))
for i, f in enumerate(train_loss):
filepath = os.path.join(path, f)
......@@ -68,7 +77,7 @@ def plot_results(path, epochs, global_epochs="True"):
results_csv = pd.read_csv(inf)
# Plot Training loss
#plt.figure(1)
norm_name = f[len("train_loss") + 1 : -len(":2022-03-24T17:54.csv")]
norm_name = f[len("train_loss") + 1 : -4].split(":")[0] # -len(":2022-03-24T17:54.csv")
if global_epochs:
rounds = results_csv["rounds"].iloc[0]
......@@ -94,7 +103,7 @@ def plot_results(path, epochs, global_epochs="True"):
mean_of_means.append(vals[0])
mean_of_std.append(vals[1])
mean_of_means = np.average(mean_of_means, axis = 0)
mean_of_std = np.average(mean_of_std, axis=0)
mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std))
losses_metrics["avg"].append(mean_of_means)
losses_metrics["std"].append(mean_of_std)
losses_metrics["name"].append(k)
......@@ -106,19 +115,19 @@ def plot_results(path, epochs, global_epochs="True"):
len(losses),
"Training Loss",
k,
"upper right",
None, #"upper right",
x_label,
ax[0]
)
tlosses = {}
tlosses_metrics = {"avg": [], "std": [], "name": []}
x_label = "global epochs"
plt.subplot(132, figsize=(5.0, 3.0))
#plt.subplot(132, figsize=(5.0, 3.0))
for i, f in enumerate(test_loss):
filepath = os.path.join(path, f)
with open(filepath, "r") as inf:
results_csv = pd.read_csv(inf)
norm_name = f[len("test_loss") + 1 : -len(":2022-03-24T17:54.csv")]
norm_name = f[len("test_loss") + 1 : -4].split(":")[0] # -len(":2022-03-24T17:54.csv")
if global_epochs:
rounds = results_csv["rounds"].iloc[0]
print("Rounds: ", rounds)
......@@ -145,7 +154,8 @@ def plot_results(path, epochs, global_epochs="True"):
mean_of_means.append(vals[0])
mean_of_std.append(vals[1])
mean_of_means = np.average(mean_of_means, axis=0)
mean_of_std = np.average(mean_of_std, axis=0)
mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std))
# np.average(mean_of_std, axis=0)
tlosses_metrics["avg"].append(mean_of_means)
tlosses_metrics["std"].append(mean_of_std)
tlosses_metrics["name"].append(k)
......@@ -157,21 +167,21 @@ def plot_results(path, epochs, global_epochs="True"):
len(tlosses),
"Testing Loss",
k,
"upper right",
None, #"upper right",
x_label,
ax[1]
)
taccs = {}
tacc_metrics = {"avg": [], "std": [], "name": []}
plt.subplot(133, figsize=(5.0, 3.0))
#plt.subplot(133, figsize=(5.0, 3.0))
for i, f in enumerate(test_acc):
filepath = os.path.join(path, f)
with open(filepath, "r") as inf:
results_csv = pd.read_csv(inf)
norm_name = f[len("test_loss") + 1: -len(":2022-03-24T17:54.csv")]
norm_name = f[len("test_acc") + 1: -4].split(":")[0] # -len(":2022-03-24T17:54.csv")
if global_epochs:
rounds = results_csv["rounds"].iloc[0]
print("Rounds: ", rounds)
......@@ -188,33 +198,33 @@ def plot_results(path, epochs, global_epochs="True"):
x_label = "communication rounds"
taccs.setdefault(norm_name, []).append((means, stdevs, x_axis))
for i, tmp in enumerate(taccs.items()):
(k, v) = tmp
mean_of_means = []
mean_of_std = []
x_axis = v[0][2]
for vals in v:
mean_of_means.append(vals[0])
mean_of_std.append(vals[1])
print(mean_of_means)
mean_of_means = np.average(mean_of_means, axis=0)
print(mean_of_means)
mean_of_std = np.average(mean_of_std, axis=0)
tacc_metrics["avg"].append(mean_of_means)
tacc_metrics["std"].append(mean_of_std)
tacc_metrics["name"].append(k)
plot_band(
x_axis,
mean_of_means,
mean_of_std,
i,
len(taccs),
"Testing Accuracy",
k,
"upper right",
x_label,
ax[2]
)
for i, tmp in enumerate(taccs.items()):
(k, v) = tmp
mean_of_means = []
mean_of_std = []
x_axis = v[0][2]
for vals in v:
mean_of_means.append(vals[0])
mean_of_std.append(vals[1])
print(mean_of_means)
mean_of_means = np.average(mean_of_means, axis=0)
print(mean_of_means)
mean_of_std = np.sqrt(np.sum(np.array(mean_of_std)**2, axis = 0)/len(mean_of_std))
tacc_metrics["avg"].append(mean_of_means)
tacc_metrics["std"].append(mean_of_std)
tacc_metrics["name"].append(k)
plot_band(
x_axis,
mean_of_means,
mean_of_std,
i,
len(taccs),
"Testing Accuracy",
k,
"lower right",
x_label,
ax[2]
)
for metric, name in zip([losses_metrics, tlosses_metrics, tacc_metrics], ["losses_metrics", "tlosses_metrics", "accuracy_metrics"]):
values = []
......@@ -239,7 +249,7 @@ def plot_results(path, epochs, global_epochs="True"):
print(pf)
pf = pf.sort_values([name.split("_")[0]+"_values"], 0, ascending=False)
pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv"))
pf.to_csv(os.path.join(path, f"best_results_{name.split('_')[0]}.csv"), float_format="%.3f")
fig.savefig(os.path.join(path, "together.svg"), dpi=300, format="svg")
fig.savefig(os.path.join(path, "together.png"), dpi=300, format="png")
......@@ -252,4 +262,4 @@ if __name__ == "__main__":
# 2: the number of epochs / comm rounds to plot for,
# 3: True/False with True meaning plot global epochs and False plot communication rounds
print(sys.argv[1], sys.argv[2], sys.argv[3])
plot_results(sys.argv[1], sys.argv[2], sys.argv[3])
plot_results(sys.argv[1], sys.argv[2], sys.argv[3])
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment