import json import logging from collections import deque import numpy import torch class Sharing: """ API defining who to share with and what, and what to do on receiving """ def __init__( self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir ): self.rank = rank self.machine_id = machine_id self.uid = mapping.get_uid(rank, machine_id) self.communication = communication self.mapping = mapping self.graph = graph self.model = model self.dataset = dataset self.log_dir = log_dir self.peer_deques = dict() my_neighbors = self.graph.neighbors(self.uid) for n in my_neighbors: self.peer_deques[n] = deque() def received_from_all(self): for _, i in self.peer_deques.items(): if len(i) == 0: return False return True def get_neighbors(self, neighbors): # modify neighbors here return neighbors def serialized_model(self): m = dict() for key, val in self.model.state_dict().items(): m[key] = json.dumps(val.numpy().tolist()) return m def deserialized_model(self, m): state_dict = dict() for key, value in m.items(): state_dict[key] = torch.from_numpy(numpy.array(json.loads(value))) return state_dict def step(self): data = self.serialized_model() my_uid = self.mapping.get_uid(self.rank, self.machine_id) all_neighbors = self.graph.neighbors(my_uid) iter_neighbors = self.get_neighbors(all_neighbors) data["degree"] = len(all_neighbors) for neighbor in iter_neighbors: self.communication.send(neighbor, data) logging.info("Waiting for messages from neighbors") while not self.received_from_all(): sender, data = self.communication.receive() logging.debug("Received model from {}".format(sender)) degree = data["degree"] del data["degree"] self.peer_deques[sender].append((degree, self.deserialized_model(data))) logging.debug("Deserialized received model from {}".format(sender)) logging.info("Starting model averaging after receiving from all neighbors") total = dict() weight_total = 0 for i, n in enumerate(self.peer_deques): logging.debug("Averaging model from neighbor {}".format(i)) degree, data = self.peer_deques[n].popleft() weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings weight_total += weight for key, value in data.items(): if key in total: total[key] += value * weight else: total[key] = value * weight for key, value in self.model.state_dict().items(): total[key] += (1 - weight_total) * value # Metro-Hastings self.model.load_state_dict(total) logging.info("Model averaging complete")