Skip to content
Snippets Groups Projects
Commit 40ccc424 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Improve result json

parent 327827e0
No related branches found
No related tags found
No related merge requests found
......@@ -173,9 +173,6 @@ class Node:
self.testset = self.dataset.get_testset()
rounds_to_test = test_after
self.train_loss = dict()
self.test_loss = dict()
self.test_acc = dict()
for iteration in range(iterations):
logging.info("Starting training iteration: %d", iteration)
......@@ -188,7 +185,25 @@ class Node:
self.trainer.reset_optimizer(self.optimizer)
loss_after_sharing = self.trainer.eval_loss(self.dataset)
self.train_loss[iteration + 1] = loss_after_sharing
if iteration:
with open(
os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
"r",
) as inf:
results_dict = json.load(inf)
else:
results_dict = {"train_loss": {}, "test_loss": {}, "test_acc": {}}
results_dict["train_loss"][iteration + 1] = loss_after_sharing
self.save_plot(
results_dict["train_loss"],
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(log_dir, "{}_train_loss.png".format(self.rank)),
)
rounds_to_test -= 1
......@@ -196,42 +211,27 @@ class Node:
logging.info("Evaluating on test set.")
rounds_to_test = test_after
ta, tl = self.dataset.test(self.model, self.loss)
self.test_acc[iteration + 1] = ta
self.test_loss[iteration + 1] = tl
results_dict["test_acc"][iteration + 1] = ta
results_dict["test_loss"][iteration + 1] = tl
self.save_plot(
self.train_loss,
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(log_dir, "{}_train_loss.png".format(self.rank)),
)
self.save_plot(
self.test_loss,
results_dict["test_loss"],
"test_loss",
"Testing Loss",
"Communication Rounds",
os.path.join(log_dir, "{}_test_loss.png".format(self.rank)),
)
self.save_plot(
self.test_acc,
results_dict["test_acc"],
"test_acc",
"Testing Accuracy",
"Communication Rounds",
os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
)
with open(
os.path.join(log_dir, "{}_train_loss.json".format(self.rank)), "w"
) as of:
json.dump(self.train_loss, of)
with open(
os.path.join(log_dir, "{}_test_loss.json".format(self.rank)), "w"
) as of:
json.dump(self.test_loss, of)
with open(
os.path.join(log_dir, "{}_test_acc.json".format(self.rank)), "w"
) as of:
json.dump(self.test_acc, of)
with open(
os.path.join(log_dir, "{}_results.json".format(self.rank)), "w"
) as of:
json.dump(results_dict, of)
self.communication.disconnect_neighbors()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment