import importlib
import json
import logging
import os

from matplotlib import pyplot as plt

from decentralizepy import utils
from decentralizepy.communication.Communication import Communication
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping


class Node:
    """
    This class defines the node (entity that performs learning, sharing and communication).
    """

    def save_plot(self, l, label, title, xlabel, filename):
        """
        Save a plot of the given metric dictionary (round -> value) to a file.
        """
        plt.clf()
        y_axis = [l[key] for key in l.keys()]
        x_axis = list(map(int, l.keys()))
        plt.plot(x_axis, y_axis, label=label)
        plt.xlabel(xlabel)
        plt.title(title)
        plt.savefig(filename)

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        log_level=logging.INFO,
        test_after=5,
        *args
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Rank of the process local to the machine
        machine_id : int
            Machine ID on which the process is running
        mapping : decentralizepy.mappings.Mapping
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs.Graph
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following:
            [DATASET]
                dataset_package
                dataset_class
                model_class
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                epochs_per_round = 25
                batch_size = 64
                loss_package
                loss_class (or loss)
            [COMMUNICATION]
                comm_package
                comm_class
            [SHARING]
                sharing_package
                sharing_class
        iterations : int
            Number of communication rounds to run
        log_dir : str
            Logging directory
        log_level : int
            One of logging.DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Evaluate on the test set every test_after rounds
        args : optional
            Other arguments
        """
        log_file = os.path.join(log_dir, str(rank) + ".log")
        logging.basicConfig(
            filename=log_file,
            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
            level=log_level,
            force=True,
        )

        logging.info("Started process.")

        self.rank = rank
        self.machine_id = machine_id
        self.graph = graph
        self.mapping = mapping
        self.uid = self.mapping.get_uid(rank, machine_id)
        self.log_dir = log_dir

        logging.debug("Rank: %d", self.rank)
        logging.debug("type(graph): %s", str(type(self.graph)))
        logging.debug("type(mapping): %s", str(type(self.mapping)))

        # Instantiate the dataset and model from the [DATASET] section.
        dataset_configs = config["DATASET"]
        dataset_module = importlib.import_module(dataset_configs["dataset_package"])
        dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
        dataset_params = utils.remove_keys(
            dataset_configs, ["dataset_package", "dataset_class", "model_class"]
        )
        self.dataset = dataset_class(rank, **dataset_params)

        logging.info("Dataset instantiation complete.")

        model_class = getattr(dataset_module, dataset_configs["model_class"])
        self.model = model_class()

        # Instantiate the optimizer from the [OPTIMIZER_PARAMS] section.
        optimizer_configs = config["OPTIMIZER_PARAMS"]
        optimizer_module = importlib.import_module(
            optimizer_configs["optimizer_package"]
        )
        optimizer_class = getattr(
            optimizer_module, optimizer_configs["optimizer_class"]
        )
        optimizer_params = utils.remove_keys(
            optimizer_configs, ["optimizer_package", "optimizer_class"]
        )
        self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params)

        # Instantiate the trainer and loss from the [TRAIN_PARAMS] section. The
        # loss is given either as a class ("loss_class") or as a callable
        # attribute of the loss package ("loss").
        train_configs = config["TRAIN_PARAMS"]
        train_module = importlib.import_module(train_configs["training_package"])
        train_class = getattr(train_module, train_configs["training_class"])

        loss_package = importlib.import_module(train_configs["loss_package"])
        if "loss_class" in train_configs.keys():
            loss_class = getattr(loss_package, train_configs["loss_class"])
            self.loss = loss_class()
        else:
            self.loss = getattr(loss_package, train_configs["loss"])

        train_params = utils.remove_keys(
            train_configs,
            [
                "training_package",
                "training_class",
                "loss",
                "loss_package",
                "loss_class",
            ],
        )
        self.trainer = train_class(
            self.model, self.optimizer, self.loss, **train_params
        )

        # Set up communication from the [COMMUNICATION] section and connect to
        # the neighbors of this node in the graph.
        comm_configs = config["COMMUNICATION"]
        comm_module = importlib.import_module(comm_configs["comm_package"])
        comm_class = getattr(comm_module, comm_configs["comm_class"])
        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
        self.communication = comm_class(
            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
        )
        self.communication.connect_neighbors(self.graph.neighbors(self.uid))

        # Instantiate the model-sharing strategy from the [SHARING] section.
        sharing_configs = config["SHARING"]
        sharing_package = importlib.import_module(sharing_configs["sharing_package"])
        sharing_class = getattr(sharing_package, sharing_configs["sharing_class"])
        sharing_params = utils.remove_keys(
            sharing_configs, ["sharing_package", "sharing_class"]
        )
        self.sharing = sharing_class(
            self.rank,
            self.machine_id,
            self.communication,
            self.mapping,
            self.graph,
            self.model,
            self.dataset,
            self.log_dir,
            **sharing_params
        )

        self.testset = self.dataset.get_testset()
        rounds_to_test = test_after

        # Main loop: local training, sharing with neighbors, and periodic
        # evaluation, with per-round metrics written to disk after every round.
        for iteration in range(iterations):
            logging.info("Starting training iteration: %d", iteration)
            self.trainer.train(self.dataset)

            self.sharing.step()

            self.optimizer = optimizer_class(
                self.model.parameters(), **optimizer_params
            )  # Reset optimizer state
            self.trainer.reset_optimizer(self.optimizer)

            loss_after_sharing = self.trainer.eval_loss(self.dataset)

            # Reload the results of previous rounds, if any, so the plots and
            # the JSON file accumulate across iterations.
            if iteration:
                with open(
                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
                    "r",
                ) as inf:
                    results_dict = json.load(inf)
            else:
                results_dict = {"train_loss": {}, "test_loss": {}, "test_acc": {}}

            results_dict["train_loss"][iteration + 1] = loss_after_sharing

            self.save_plot(
                results_dict["train_loss"],
                "train_loss",
                "Training Loss",
                "Communication Rounds",
                os.path.join(log_dir, "{}_train_loss.png".format(self.rank)),
            )

            rounds_to_test -= 1

            if self.dataset.__testing__ and rounds_to_test == 0:
                logging.info("Evaluating on test set.")
                rounds_to_test = test_after
                ta, tl = self.dataset.test(self.model, self.loss)
                results_dict["test_acc"][iteration + 1] = ta
                results_dict["test_loss"][iteration + 1] = tl

                self.save_plot(
                    results_dict["test_loss"],
                    "test_loss",
                    "Testing Loss",
                    "Communication Rounds",
                    os.path.join(log_dir, "{}_test_loss.png".format(self.rank)),
                )
                self.save_plot(
                    results_dict["test_acc"],
                    "test_acc",
                    "Testing Accuracy",
                    "Communication Rounds",
                    os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
                )

            with open(
                os.path.join(log_dir, "{}_results.json".format(self.rank)), "w"
            ) as of:
                json.dump(results_dict, of)

        self.communication.disconnect_neighbors()