Skip to content
Snippets Groups Projects
Commit 77a1296e authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Merge branch 'centralized_testing_fix' into 'main'

Only start star topology when needed

See merge request sacs/decentralizepy!13
parents bf5f03d1 2367e854
No related branches found
No related tags found
No related merge requests found
......@@ -243,6 +243,17 @@ class Node:
comm_class = getattr(comm_module, comm_configs["comm_class"])
comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
self.addresses_filepath = comm_params.get("addresses_filepath", None)
if self.centralized_test_eval:
self.testing_comm = TCP(
self.rank,
self.machine_id,
self.mapping,
self.star.n_procs,
self.addresses_filepath,
offset=self.star.n_procs,
)
self.testing_comm.connect_neighbors(self.star.neighbors(self.uid))
self.communication = comm_class(
self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
)
......@@ -360,16 +371,6 @@ class Node:
self.testset = self.dataset.get_testset()
self.communication.connect_neighbors(self.graph.neighbors(self.uid))
rounds_to_test = self.test_after
testing_comm = TCP(
self.rank,
self.machine_id,
self.mapping,
self.star.n_procs,
self.addresses_filepath,
offset=self.star.n_procs,
)
testing_comm.connect_neighbors(self.star.neighbors(self.uid))
rounds_to_train_evaluate = self.train_evaluate_after
global_epoch = 1
change = 1
......@@ -387,19 +388,20 @@ class Node:
**dataset_params_copy
)
dataset = self.whole_dataset
tthelper = TrainTestHelper(
dataset, # self.whole_dataset,
# self.model_test, # todo: this only works if eval_train is set to false
self.model,
self.loss,
self.weights_store_dir,
self.mapping.get_n_procs(),
self.trainer,
testing_comm,
self.star,
self.threads_per_proc,
eval_train=self.centralized_train_eval,
)
if self.centralized_test_eval:
tthelper = TrainTestHelper(
dataset, # self.whole_dataset,
# self.model_test, # todo: this only works if eval_train is set to false
self.model,
self.loss,
self.weights_store_dir,
self.mapping.get_n_procs(),
self.trainer,
self.testing_comm,
self.star,
self.threads_per_proc,
eval_train=self.centralized_train_eval,
)
for iteration in range(self.iterations):
logging.info("Starting training iteration: %d", iteration)
......@@ -475,8 +477,8 @@ class Node:
if trl is not None:
results_dict["train_loss"][iteration + 1] = trl
else:
testing_comm.send(0, self.model.get_weights())
sender, data = testing_comm.receive()
self.testing_comm.send(0, self.model.get_weights())
sender, data = self.testing_comm.receive()
assert sender == 0 and data == "finished"
else:
logging.info("Evaluating on test set.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment