Skip to content
Snippets Groups Projects
Sharing.py 3.02 KiB
Newer Older
Rishi Sharma's avatar
Rishi Sharma committed
import json
import logging
Rishi Sharma's avatar
Rishi Sharma committed
from collections import deque

Rishi Sharma's avatar
Rishi Sharma committed
import numpy
Rishi Sharma's avatar
Rishi Sharma committed
import torch

Rishi Sharma's avatar
Rishi Sharma committed

Rishi Sharma's avatar
Rishi Sharma committed
class Sharing:
    """
    API defining who to share with and what, and what to do on receiving.

    This base implementation sends the full local model state to the
    neighbors chosen by ``get_neighbors``. It waits until one model has been
    queued from every graph neighbor. It then replaces the local model with a
    Metropolis-Hastings weighted average of the local model and the received
    models.
    """

    def __init__(
        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
    ):
        """
        Parameters
        ----------
        rank
            Rank of this process; combined with ``machine_id`` through
            ``mapping.get_uid`` to obtain this node's uid.
        machine_id
            Machine identifier, see ``rank``.
        communication
            Transport object; must provide ``send(uid, data)`` and
            ``receive()`` returning a ``(sender_uid, data)`` pair.
        mapping
            Object providing ``get_uid(rank, machine_id)``.
        graph
            Topology object providing ``neighbors(uid)``; the result must
            support ``len`` and iteration.
        model
            Object exposing ``state_dict()`` / ``load_state_dict()``
            (presumably a ``torch.nn.Module``) whose state is shared.
        dataset
            Stored on the instance; not read by this base class.
        log_dir
            Stored on the instance; not read by this base class.
        """
        self.rank = rank
        self.machine_id = machine_id
        self.uid = mapping.get_uid(rank, machine_id)
        self.communication = communication
        self.mapping = mapping
        self.graph = graph
        self.model = model
        self.dataset = dataset
        self.log_dir = log_dir

        # One FIFO per neighbor: a neighbor may deliver a model for a later
        # round before we have consumed the current one.
        self.peer_deques = {n: deque() for n in self.graph.neighbors(self.uid)}

    def received_from_all(self):
        """Return True when at least one model is queued for every neighbor."""
        # An empty deque is falsy, so this is "every deque is non-empty".
        return all(self.peer_deques.values())

    def get_neighbors(self, neighbors):
        """
        Select which neighbors to send to this round.

        Hook for subclasses; the base implementation returns ``neighbors``
        unchanged.
        """
        return neighbors

    def serialized_model(self):
        """
        Serialize the local model state.

        Returns
        -------
        dict
            Maps each ``state_dict`` key to a JSON string of the tensor's
            values as nested lists. ``Tensor.numpy()`` only works on CPU
            tensors, so the model state is expected to live on the CPU.
        """
        return {
            key: json.dumps(val.numpy().tolist())
            for key, val in self.model.state_dict().items()
        }

    def deserialized_model(self, m):
        """
        Inverse of ``serialized_model``.

        Parameters
        ----------
        m : dict
            Maps ``state_dict`` keys to JSON-encoded nested lists.

        Returns
        -------
        dict
            ``state_dict``-style mapping of key to tensor. The dtype is
            inferred by NumPy from the JSON numbers (float64 for floats,
            int64 for ints), not the original tensor dtype.
        """
        return {
            key: torch.from_numpy(numpy.array(json.loads(value)))
            for key, value in m.items()
        }

    def step(self):
        """
        Run one round of model exchange followed by weighted averaging.

        1. Send the serialized local model, tagged with this node's degree,
           to every neighbor returned by ``get_neighbors``.
        2. Receive messages until every neighbor has at least one queued model.
        3. Pop one model per neighbor and load
           ``sum_j w_j * x_j + (1 - sum_j w_j) * x_self`` into the local model,
           where ``w_j = 1 / (max(deg_self, deg_j) + 1)`` (Metropolis-Hastings).

        A node without neighbors keeps its model unchanged.
        """
        data = self.serialized_model()
        all_neighbors = self.graph.neighbors(self.uid)
        iter_neighbors = self.get_neighbors(all_neighbors)
        # Advertise the full graph degree. The receiver compares it against
        # its own full degree (len(self.peer_deques)) in the weight formula.
        data["degree"] = len(all_neighbors)
        for neighbor in iter_neighbors:
            self.communication.send(neighbor, data)

        logging.info("Waiting for messages from neighbors")
        while not self.received_from_all():
            sender, data = self.communication.receive()
            logging.debug("Received model from %s", sender)
            degree = data.pop("degree")
            self.peer_deques[sender].append((degree, self.deserialized_model(data)))
            logging.debug("Deserialized received model from %s", sender)

        logging.info("Starting model averaging after receiving from all neighbors")
        own_degree = len(self.peer_deques)
        total = dict()
        weight_total = 0
        for neighbor, queue in self.peer_deques.items():
            logging.debug("Averaging model from neighbor %s", neighbor)
            degree, received = queue.popleft()
            weight = 1 / (max(own_degree, degree) + 1)  # Metro-Hastings
            weight_total += weight
            for key, value in received.items():
                if key in total:
                    total[key] += value * weight
                else:
                    total[key] = value * weight

        # The local model takes the remaining weight so all weights sum to 1.
        for key, value in self.model.state_dict().items():
            own = (1 - weight_total) * value  # Metro-Hastings
            if key in total:
                total[key] += own
            else:
                # No neighbor contributed this key (e.g. the node has no
                # neighbors); previously this raised KeyError.
                total[key] = own

        self.model.load_state_dict(total)
        logging.info("Model averaging complete")