Newer
Older
class Sharing:
"""
API defining who to share with and what, and what to do on receiving
def __init__(
self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
):
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph reprensenting neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yer! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
self.rank = rank
self.machine_id = machine_id
self.uid = mapping.get_uid(rank, machine_id)
self.communication = communication
self.mapping = mapping
self.graph = graph
self.model = model
self.dataset = dataset
self.peer_deques = dict()
my_neighbors = self.graph.neighbors(self.uid)
for n in my_neighbors:
self.peer_deques[n] = deque()
def received_from_all(self):
"""
Check if all neighbors have sent the current iteration
Returns
-------
bool
True if required data has been received, False otherwise
for _, i in self.peer_deques.items():
if len(i) == 0:
return False
return True
def get_neighbors(self, neighbors):
"""
Choose which neighbors to share with
Parameters
----------
neighbors : list(int)
List of all neighbors
Returns
-------
list(int)
Neighbors to share with
# modify neighbors here
return neighbors
def serialized_model(self):
Convert model to a dictionary. Here we can choose how much to share
self.total_data += len(self.communication.encrypt(m[key]))
Parameters
----------
m : dict
Returns
-------
state_dict
state_dict of received
"""
Perform a sharing step. Implements D-PSGD.
"""
data = self.serialized_model()
my_uid = self.mapping.get_uid(self.rank, self.machine_id)
all_neighbors = self.graph.neighbors(my_uid)
iter_neighbors = self.get_neighbors(all_neighbors)
for neighbor in iter_neighbors:
self.communication.send(neighbor, data)
logging.info("Waiting for messages from neighbors")
while not self.received_from_all():
sender, data = self.communication.receive()
logging.debug("Received model from {}".format(sender))
logging.info(
"Deserialized received model from {} of iteration {}".format(
sender, iteration
)
)
logging.info("Starting model averaging after receiving from all neighbors")
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(n, iteration)
data = self.deserialized_model(data)
weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings
weight_total += weight
for key, value in data.items():
if key in total:
total[key] += value * weight
else:
total[key] = value * weight
for key, value in self.model.state_dict().items():
total[key] += (1 - weight_total) * value # Metro-Hastings