import importlib
import logging
import os

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping


class Node:
    """
    This class defines the node (entity that performs learning, sharing and communication).
    """

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        log_level=logging.INFO,
        *args
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process is running
        mapping : decentralizepy.mappings.Mapping
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs.Graph
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following:
            [DATASET]
                dataset_package
                dataset_class
                model_class
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                epochs_per_round = 25
                batch_size = 64
        iterations : int
            Number of training rounds to run
        log_dir : str
            Logging directory
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        args : optional
            Other arguments
        """
        log_file = os.path.join(log_dir, str(rank) + ".log")
        logging.basicConfig(
            filename=log_file,
            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
            level=log_level,
            force=True,
        )

        logging.info("Started process.")

        self.rank = rank
        self.graph = graph
        self.mapping = mapping

        logging.debug("Rank: %d", self.rank)
        logging.debug("type(graph): %s", str(type(self.graph)))
        logging.debug("type(mapping): %s", str(type(self.mapping)))

        # Instantiate the dataset from the [DATASET] section; everything in
        # the section besides the package/class names is forwarded as kwargs.
        dataset_configs = config["DATASET"]
        dataset_module = importlib.import_module(dataset_configs["dataset_package"])
        dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
        dataset_params = utils.remove_keys(
            dataset_configs, ["dataset_package", "dataset_class", "model_class"]
        )
        self.dataset = dataset_class(rank, **dataset_params)

        logging.info("Dataset instantiation complete.")

        # The model class is expected to live in the same module as the dataset.
        model_class = getattr(dataset_module, dataset_configs["model_class"])
        self.model = model_class()

        # Instantiate the optimizer over the model parameters.
        optimizer_configs = config["OPTIMIZER_PARAMS"]
        optimizer_module = importlib.import_module(
            optimizer_configs["optimizer_package"]
        )
        optimizer_class = getattr(
            optimizer_module, optimizer_configs["optimizer_class"]
        )
        optimizer_params = utils.remove_keys(
            optimizer_configs, ["optimizer_package", "optimizer_class"]
        )
        self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params)

        # Build the trainer. The loss is either a class to construct
        # ("loss_class") or an already-callable attribute ("loss").
        train_configs = config["TRAIN_PARAMS"]
        train_module = importlib.import_module(train_configs["training_package"])
        train_class = getattr(train_module, train_configs["training_class"])

        loss_package = importlib.import_module(train_configs["loss_package"])
        if "loss_class" in train_configs.keys():
            loss_class = getattr(loss_package, train_configs["loss_class"])
            loss = loss_class()
        else:
            loss = getattr(loss_package, train_configs["loss"])

        train_params = utils.remove_keys(
            train_configs,
            [
                "training_package",
                "training_class",
                "loss",
                "loss_package",
                "loss_class",
            ],
        )
        self.trainer = train_class(self.model, self.optimizer, loss, **train_params)

        self.testset = self.dataset.get_testset()

        for iteration in range(iterations):
            logging.info("Starting training iteration: %d", iteration)
            self.trainer.train(self.dataset)

            if self.dataset.__testing__:
                self.dataset.test(self.model)
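

# ---------------------------------------------------------------------------
# Example usage: a minimal, illustrative sketch of constructing a Node. The
# dataset/model names ("Femnist", "CNN"), the Linear mapping signature, the
# graph file path, and the Graph.read_graph_from_file call below are
# assumptions for illustration only; substitute the packages, classes, and
# files available in your deployment, and note that a real run also needs
# the dataset present on disk.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from decentralizepy.mappings.Linear import Linear  # assumed mapping class

    # A plain dict of dicts mirrors the sections Node.__init__ reads; a
    # configparser object with the same sections would work equally well.
    config = {
        "DATASET": {
            "dataset_package": "decentralizepy.datasets.Femnist",  # assumed
            "dataset_class": "Femnist",  # assumed
            "model_class": "CNN",  # assumed class in the same module
        },
        "OPTIMIZER_PARAMS": {
            "optimizer_package": "torch.optim",
            "optimizer_class": "SGD",
            "lr": 0.01,
        },
        "TRAIN_PARAMS": {
            "training_package": "decentralizepy.training.Training",
            "training_class": "Training",
            "epochs_per_round": 25,
            "batch_size": 64,
            "loss_package": "torch.nn",
            "loss_class": "CrossEntropyLoss",
        },
    }

    # One machine running two processes; "ring_2.edges" is a hypothetical
    # edge-list file describing the communication topology.
    graph = Graph()
    graph.read_graph_from_file("ring_2.edges", "edges")
    mapping = Linear(1, 2)

    Node(
        rank=0,
        machine_id=0,
        mapping=mapping,
        graph=graph,
        config=config,
        iterations=1,
        log_dir=".",
    )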