# Sharing.py
# (Recovered from a GitLab blame view; commits by Rishi Sharma.)
from collections import deque
import json
import logging
import torch
import numpy

class Sharing:
    """
    API defining who to share with and what, and what to do on receiving.

    Each call to :meth:`step` sends this node's serialized model to its
    neighbors, blocks until one model has been queued from every neighbor,
    and then replaces the local model with a Metropolis-Hastings weighted
    average of its own state and the neighbors' states.
    """

    def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset):
        """
        Parameters
        ----------
        rank : int
            Rank of this node; combined with ``machine_id`` to obtain its uid.
        machine_id : int
            Machine identifier; combined with ``rank`` to obtain its uid.
        communication
            Object providing ``send(uid, data)`` and ``receive()``, where
            ``receive`` returns a ``(sender_uid, data)`` pair.
        mapping
            Object providing ``get_uid(rank, machine_id)``.
        graph
            Object providing ``neighbors(uid)``, the neighbors of a node.
        model : torch.nn.Module
            Model whose ``state_dict`` is shared and averaged.
        dataset
            Stored on the instance; not read by this class.
        """
        self.rank = rank
        self.machine_id = machine_id
        self.uid = mapping.get_uid(rank, machine_id)
        self.communication = communication
        self.mapping = mapping
        self.graph = graph
        self.model = model
        self.dataset = dataset

        # One FIFO per neighbor: a neighbor's model may arrive before this
        # node has consumed the previous one, so messages are queued and one
        # entry per neighbor is consumed by each step().
        self.peer_deques = {n: deque() for n in self.graph.neighbors(self.uid)}

    def received_from_all(self):
        """Return True when every neighbor's queue holds at least one model."""
        return all(len(queue) > 0 for queue in self.peer_deques.values())

    def get_neighbors(self, neighbors):
        """
        Hook selecting which neighbors to send to; subclasses may override.

        The default sends to every neighbor.
        """
        return neighbors

    def serialized_model(self):
        """
        Serialize the model's state dict for sending.

        Returns
        -------
        dict
            Maps each state-dict key to a JSON string holding the tensor's
            values as (nested) lists. The tensor dtype is not encoded.
        """
        return {
            key: json.dumps(val.numpy().tolist())
            for key, val in self.model.state_dict().items()
        }

    def deserialized_model(self, m):
        """
        Inverse of :meth:`serialized_model`: rebuild tensors from JSON strings.

        numpy picks the dtype from the decoded values (e.g. Python floats
        become float64), so the sender's original dtype is not preserved.

        Returns
        -------
        dict
            Maps each key to a ``torch.Tensor``.
        """
        return {
            key: torch.from_numpy(numpy.array(json.loads(value)))
            for key, value in m.items()
        }

    def step(self):
        """
        Run one round of gossip averaging with all neighbors.

        Sends this node's model, tagged with its degree, to the neighbors
        chosen by :meth:`get_neighbors`. It then blocks until one model per
        neighbor is queued and loads the Metropolis-Hastings average into
        ``self.model``. A node with no neighbors keeps its model unchanged.
        """
        all_neighbors = self.graph.neighbors(self.uid)
        payload = self.serialized_model()
        # Receivers need our full degree (not only the subset we send to)
        # to compute the Metropolis-Hastings weight.
        payload["degree"] = len(all_neighbors)
        for neighbor in self.get_neighbors(all_neighbors):
            self.communication.send(neighbor, payload)

        while not self.received_from_all():
            sender, message = self.communication.receive()
            logging.info("Received model from %s", sender)
            degree = message.pop("degree")
            # NOTE(review): raises KeyError if sender is not a neighbor.
            self.peer_deques[sender].append((degree, self.deserialized_model(message)))

        if not self.peer_deques:
            # Isolated node: there is nothing to average with. Without this
            # guard the self-weight loop below would index an empty dict.
            logging.info("No neighbors; keeping the local model unchanged")
            return

        logging.info("Starting model averaging after receiving from all neighbors")

        my_degree = len(self.peer_deques)
        total = dict()
        weight_total = 0
        for neighbor, queue in self.peer_deques.items():
            logging.debug("Averaging model from neighbor %s", neighbor)
            degree, state = queue.popleft()
            # Metropolis-Hastings weight: 1 / (1 + max(deg_self, deg_neighbor)).
            weight = 1 / (max(my_degree, degree) + 1)
            weight_total += weight
            for key, value in state.items():
                if key in total:
                    total[key] += value * weight
                else:
                    total[key] = value * weight

        # The remaining mass goes to our own model so all weights sum to 1.
        self_weight = 1 - weight_total
        for key, value in self.model.state_dict().items():
            total[key] += self_weight * value

        self.model.load_state_dict(total)

        logging.info("Model averaging complete")