diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index dd1b9d9bbe8eee4e4bc28637ac66c9adff64768d..91f34e5ff4cb8540315c23e22fa951f2df7eac76 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -243,6 +243,17 @@ class Node:
         comm_class = getattr(comm_module, comm_configs["comm_class"])
         comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
         self.addresses_filepath = comm_params.get("addresses_filepath", None)
+        if self.centralized_test_eval:
+            self.testing_comm = TCP(
+                self.rank,
+                self.machine_id,
+                self.mapping,
+                self.star.n_procs,
+                self.addresses_filepath,
+                offset=self.star.n_procs,
+            )
+            self.testing_comm.connect_neighbors(self.star.neighbors(self.uid))
+
         self.communication = comm_class(
             self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
         )
@@ -360,16 +371,6 @@ class Node:
         self.testset = self.dataset.get_testset()
         self.communication.connect_neighbors(self.graph.neighbors(self.uid))
         rounds_to_test = self.test_after
-
-        testing_comm = TCP(
-            self.rank,
-            self.machine_id,
-            self.mapping,
-            self.star.n_procs,
-            self.addresses_filepath,
-            offset=self.star.n_procs,
-        )
-        testing_comm.connect_neighbors(self.star.neighbors(self.uid))
         rounds_to_train_evaluate = self.train_evaluate_after
         global_epoch = 1
         change = 1
@@ -387,19 +388,20 @@ class Node:
                 **dataset_params_copy
             )
             dataset = self.whole_dataset
-            tthelper = TrainTestHelper(
-                dataset,  # self.whole_dataset,
-                # self.model_test, # todo: this only works if eval_train is set to false
-                self.model,
-                self.loss,
-                self.weights_store_dir,
-                self.mapping.get_n_procs(),
-                self.trainer,
-                testing_comm,
-                self.star,
-                self.threads_per_proc,
-                eval_train=self.centralized_train_eval,
-            )
+            if self.centralized_test_eval:
+                tthelper = TrainTestHelper(
+                    dataset,  # self.whole_dataset,
+                    # self.model_test, # todo: this only works if eval_train is set to false
+                    self.model,
+                    self.loss,
+                    self.weights_store_dir,
+                    self.mapping.get_n_procs(),
+                    self.trainer,
+                    self.testing_comm,
+                    self.star,
+                    self.threads_per_proc,
+                    eval_train=self.centralized_train_eval,
+                )
 
         for iteration in range(self.iterations):
             logging.info("Starting training iteration: %d", iteration)
@@ -475,8 +477,8 @@ class Node:
                         if trl is not None:
                             results_dict["train_loss"][iteration + 1] = trl
                     else:
-                        testing_comm.send(0, self.model.get_weights())
-                        sender, data = testing_comm.receive()
+                        self.testing_comm.send(0, self.model.get_weights())
+                        sender, data = self.testing_comm.receive()
                         assert sender == 0 and data == "finished"
                 else:
                     logging.info("Evaluating on test set.")
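
The net effect of the patch: the star-topology test channel is built once at construction time, stored as self.testing_comm, and created only when centralized_test_eval is enabled, instead of being rebuilt inside run(). A minimal standalone sketch of that gating pattern follows; StubStar, StubTCP, and NodeSketch are hypothetical stand-ins for illustration, not decentralizepy's actual Star graph, TCP communication, or Node classes.

# Sketch only: StubStar and StubTCP are hypothetical placeholders for the
# decentralizepy classes referenced in the patch above.
class StubStar:
    def __init__(self, n_procs):
        self.n_procs = n_procs

    def neighbors(self, uid):
        # Star topology: node 0 is the hub, every other node talks only to 0.
        return set(range(1, self.n_procs)) if uid == 0 else {0}


class StubTCP:
    def __init__(self, rank, machine_id, n_procs, offset=0):
        self.rank = rank
        self.machine_id = machine_id
        self.n_procs = n_procs
        self.offset = offset  # offset keeps this channel distinct from the main one

    def connect_neighbors(self, neighbors):
        print(f"rank {self.rank}: connecting test channel to {sorted(neighbors)}")


class NodeSketch:
    def __init__(self, rank, machine_id, uid, centralized_test_eval):
        self.centralized_test_eval = centralized_test_eval
        self.star = StubStar(n_procs=4)
        # Build the test-evaluation channel once, and only if centralized test
        # evaluation is enabled, mirroring the patch's move of this setup from
        # run() into node construction.
        if self.centralized_test_eval:
            self.testing_comm = StubTCP(
                rank, machine_id, self.star.n_procs, offset=self.star.n_procs
            )
            self.testing_comm.connect_neighbors(self.star.neighbors(uid))


if __name__ == "__main__":
    NodeSketch(rank=1, machine_id=0, uid=1, centralized_test_eval=True)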