# Sharing.py
import logging
import pickle
from collections import deque

import numpy
import torch


class Sharing:
    """
    API defining who to share with and what, and what to do on receiving

    """

    def __init__(
        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph representing neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine)

        """
        self.rank = rank
        self.machine_id = machine_id
        self.uid = mapping.get_uid(rank, machine_id)
        self.communication = communication
        self.mapping = mapping
        self.graph = graph
        self.model = model
        self.dataset = dataset
        self.communication_round = 0
        self.log_dir = log_dir
        # Running total of len(communication.encrypt(...)) over serialized values.
        self.total_data = 0

        # One FIFO queue of received (degree, iteration, data) tuples per neighbor.
        self.peer_deques = {n: deque() for n in self.graph.neighbors(self.uid)}

    def received_from_all(self):
        """
        Check if all neighbors have sent the current iteration

        Returns
        -------
        bool
            True if every neighbor queue holds at least one message (this is
            trivially True when there are no neighbors), False otherwise

        """
        return all(len(queue) > 0 for queue in self.peer_deques.values())

    def get_neighbors(self, neighbors):
        """
        Choose which neighbors to share with

        Parameters
        ----------
        neighbors : list(int)
            List of all neighbors

        Returns
        -------
        list(int)
            Neighbors to share with

        """
        # modify neighbors here
        return neighbors

    def serialized_model(self):
        """
        Convert model to a dictionary. Here we can choose how much to share

        As a side effect, ``self.total_data`` grows by the length of the
        encrypted form of every serialized value.

        Returns
        -------
        dict
            Model converted to dict (state_dict key -> numpy array)

        """
        m = dict()
        for key, val in self.model.state_dict().items():
            # NOTE(review): Tensor.numpy() only works on CPU tensors and the
            # resulting array shares memory with the tensor -- confirm the
            # model is kept on CPU and that send() serializes eagerly.
            m[key] = val.numpy()
            self.total_data += len(self.communication.encrypt(m[key]))
        return m

    def deserialized_model(self, m):
        """
        Convert received dict to state_dict.

        Parameters
        ----------
        m : dict
            received dict

        Returns
        -------
        state_dict
            state_dict of received

        """
        return {key: torch.from_numpy(value) for key, value in m.items()}

    def step(self):
        """
        Perform a sharing step. Implements D-PSGD.

        Sends the serialized model (tagged with this node's degree and the
        current communication round) to the chosen neighbors, blocks until one
        message from every neighbor is queued, then replaces the local model
        with a weighted average of the received models and its own.

        """
        data = self.serialized_model()
        my_uid = self.mapping.get_uid(self.rank, self.machine_id)
        all_neighbors = self.graph.neighbors(my_uid)
        iter_neighbors = self.get_neighbors(all_neighbors)
        data["degree"] = len(all_neighbors)
        data["iteration"] = self.communication_round
        for neighbor in iter_neighbors:
            self.communication.send(neighbor, data)

        logging.info("Waiting for messages from neighbors")
        while not self.received_from_all():
            sender, received = self.communication.receive()
            logging.debug("Received model from {}".format(sender))
            # Strip the metadata so only model entries remain in the payload.
            degree = received.pop("degree")
            iteration = received.pop("iteration")
            self.peer_deques[sender].append((degree, iteration, received))
            logging.info(
                "Deserialized received model from {} of iteration {}".format(
                    sender, iteration
                )
            )

        logging.info("Starting model averaging after receiving from all neighbors")
        total = dict()
        weight_total = 0
        for n, queue in self.peer_deques.items():
            degree, iteration, received = queue.popleft()
            logging.debug(
                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
            )
            received = self.deserialized_model(received)
            # Metro-Hastings: 1 / (max(own degree, sender degree) + 1)
            weight = 1 / (max(len(self.peer_deques), degree) + 1)
            weight_total += weight
            for key, value in received.items():
                if key in total:
                    total[key] += value * weight
                else:
                    total[key] = value * weight

        for key, value in self.model.state_dict().items():
            # Metro-Hastings: the local model gets the remaining weight.
            own = (1 - weight_total) * value
            # A key is absent from total when no neighbor contributed it (in
            # particular when this node has no neighbors); start from the
            # local contribution instead of raising KeyError.
            if key in total:
                total[key] += own
            else:
                total[key] = own

        self.model.load_state_dict(total)

        logging.info("Model averaging complete")

        self.communication_round += 1