# Node.py
# Author: Rishi Sharma
import importlib
import logging
import os

# Project-internal imports
from decentralizepy import utils
# Component types (Graph and Mapping are used in Node's type annotations)
from decentralizepy.communication.Communication import Communication
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping


class Node:
    """
    This class defines the node (entity that performs learning, sharing and communication).

    Constructing a Node does more than build it: the constructor instantiates
    every component from the configuration, runs the complete training loop
    (train -> share -> periodically test) and finally disconnects from the
    neighbours before returning.
    """

    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        log_level=logging.INFO,
        test_after=5,
    ):
        """
        Constructor. Builds the node's components from ``config``, runs
        ``iterations`` rounds of training + sharing, then disconnects.

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process is running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following sections
            (remaining keys of DATASET, OPTIMIZER_PARAMS, TRAIN_PARAMS and
            COMMUNICATION, after utils.remove_keys strips the listed ones, are
            forwarded as keyword arguments to the respective constructor):
            [DATASET]
                dataset_package
                dataset_class
                model_class          (looked up in dataset_package)
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                loss_package
                loss_class           (instantiated with no arguments), or
                loss                 (attribute of loss_package used as-is)
                epochs_per_round = 25
                batch_size = 64
            [COMMUNICATION]
                comm_package
                comm_class
            [SHARING]
                sharing_package
                sharing_class
        iterations : int
            Number of training + sharing rounds to run
        log_dir : str
            Logging directory; the log file is ``<log_dir>/<rank>.log``
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Evaluate the model every ``test_after`` rounds (only when the
            dataset's ``__testing__`` flag is truthy). With ``test_after <= 0``
            the model is never evaluated.
        """
        self._init_logging(log_dir, rank, log_level)

        self.rank = rank
        self.machine_id = machine_id
        self.graph = graph
        self.mapping = mapping
        self.uid = self.mapping.get_uid(rank, machine_id)

        logging.debug("Rank: %d", self.rank)
        logging.debug("type(graph): %s", str(type(self.graph)))
        logging.debug("type(mapping): %s", str(type(self.mapping)))

        # Order matters: the optimizer needs the model, the trainer needs the
        # model + optimizer, and sharing needs communication, model and dataset.
        self._init_dataset_model(config["DATASET"])
        self._init_optimizer(config["OPTIMIZER_PARAMS"])
        self._init_trainer(config["TRAIN_PARAMS"])
        self._init_comm(config["COMMUNICATION"])
        self._init_sharing(config["SHARING"])

        self.testset = self.dataset.get_testset()

        self._run(iterations, test_after)

        # NOTE(review): if training/sharing raises, this disconnect is skipped.
        # Consider try/finally once it is confirmed that disconnect_neighbors()
        # cannot block on unresponsive peers in the error path.
        self.communication.disconnect_neighbors()

    def _init_logging(self, log_dir, rank, log_level):
        """Route the root logger to ``<log_dir>/<rank>.log``, replacing existing handlers."""
        log_file = os.path.join(log_dir, str(rank) + ".log")
        logging.basicConfig(
            filename=log_file,
            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
            level=log_level,
            force=True,  # drop handlers configured earlier in this process
        )

        logging.info("Started process.")

    def _init_dataset_model(self, dataset_configs):
        """Instantiate ``self.dataset`` and ``self.model`` from the [DATASET] section."""
        dataset_module = importlib.import_module(dataset_configs["dataset_package"])
        dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
        dataset_params = utils.remove_keys(
            dataset_configs, ["dataset_package", "dataset_class", "model_class"]
        )
        # NOTE(review): the dataset receives the machine-local rank rather than
        # self.uid; ranks repeat across machines — confirm this is intended.
        self.dataset = dataset_class(self.rank, **dataset_params)
        logging.info("Dataset instantiation complete.")

        # The model class is looked up in the same module as the dataset class.
        model_class = getattr(dataset_module, dataset_configs["model_class"])
        self.model = model_class()

    def _init_optimizer(self, optimizer_configs):
        """Instantiate ``self.optimizer`` over the model parameters from [OPTIMIZER_PARAMS]."""
        optimizer_module = importlib.import_module(
            optimizer_configs["optimizer_package"]
        )
        optimizer_class = getattr(
            optimizer_module, optimizer_configs["optimizer_class"]
        )
        optimizer_params = utils.remove_keys(
            optimizer_configs, ["optimizer_package", "optimizer_class"]
        )
        self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params)

    def _init_trainer(self, train_configs):
        """Build the loss and instantiate ``self.trainer`` from the [TRAIN_PARAMS] section."""
        train_module = importlib.import_module(train_configs["training_package"])
        train_class = getattr(train_module, train_configs["training_class"])

        loss_package = importlib.import_module(train_configs["loss_package"])
        if "loss_class" in train_configs:
            # A loss class: instantiate it with no arguments.
            loss_class = getattr(loss_package, train_configs["loss_class"])
            loss = loss_class()
        else:
            # Otherwise the named attribute is used directly as the loss.
            loss = getattr(loss_package, train_configs["loss"])

        train_params = utils.remove_keys(
            train_configs,
            [
                "training_package",
                "training_class",
                "loss",
                "loss_package",
                "loss_class",
            ],
        )
        self.trainer = train_class(self.model, self.optimizer, loss, **train_params)

    def _init_comm(self, comm_configs):
        """Instantiate ``self.communication`` and connect to this node's graph neighbours."""
        comm_module = importlib.import_module(comm_configs["comm_package"])
        comm_class = getattr(comm_module, comm_configs["comm_class"])
        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
        self.communication = comm_class(
            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
        )
        self.communication.connect_neighbors(self.graph.neighbors(self.uid))

    def _init_sharing(self, sharing_configs):
        """Instantiate ``self.sharing`` from the [SHARING] section."""
        sharing_package = importlib.import_module(sharing_configs["sharing_package"])
        sharing_class = getattr(sharing_package, sharing_configs["sharing_class"])
        self.sharing = sharing_class(
            self.rank,
            self.machine_id,
            self.communication,
            self.mapping,
            self.graph,
            self.model,
            self.dataset,
        )

    def _run(self, iterations, test_after):
        """Run ``iterations`` rounds of train -> share, testing every ``test_after`` rounds."""
        rounds_to_test = test_after

        for iteration in range(iterations):
            logging.info("Starting training iteration: %d", iteration)
            self.trainer.train(self.dataset)

            self.sharing.step()

            # Countdown to the next evaluation; it only reaches exactly 0 when
            # test_after > 0, so non-positive values disable testing.
            rounds_to_test -= 1

            if self.dataset.__testing__ and rounds_to_test == 0:
                rounds_to_test = test_after
                self.dataset.test(self.model)