Newer
Older
from collections import deque
import json
import logging
import torch
import numpy
class Sharing:
    """
    API defining who to share with and what, and what to do on receiving.

    This base implementation shares the full model ``state_dict`` with the
    neighbors chosen by ``get_neighbors``. It waits until a model has arrived
    from every graph neighbor, then replaces the local model with a
    Metropolis-Hastings weighted average of the local and received models.
    ``get_neighbors`` appears intended as a hook for subclasses to restrict
    who is sent to.
    """

    def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset):
        """
        Args:
            rank: Passed with ``machine_id`` to ``mapping.get_uid`` to obtain
                this node's uid.
            machine_id: See ``rank``.
            communication: Object with ``send(uid, data)`` and ``receive()``;
                ``receive()`` must return a ``(sender_uid, data)`` pair.
            mapping: Object with ``get_uid(rank, machine_id)``.
            graph: Object with ``neighbors(uid)`` returning the neighbor uids
                (must support ``len`` and iteration).
            model: Object with ``state_dict()`` / ``load_state_dict()``;
                presumably a ``torch.nn.Module``.
            dataset: Stored as ``self.dataset``; not used by this class.
        """
        self.rank = rank
        self.machine_id = machine_id
        self.uid = mapping.get_uid(rank, machine_id)
        self.communication = communication
        self.mapping = mapping
        self.graph = graph
        self.model = model
        self.dataset = dataset
        # One FIFO of (degree, state_dict) per neighbor. Each step consumes one
        # entry per neighbor, so extra models that arrive early (presumably
        # from a neighbor already in a later round) wait for the next step.
        self.peer_deques = {n: deque() for n in self.graph.neighbors(self.uid)}

    def received_from_all(self):
        """Return True when every neighbor has at least one queued model."""
        # An empty deque is falsy, so this is True only if none are empty.
        return all(self.peer_deques.values())

    def get_neighbors(self, neighbors):
        """
        Choose which neighbors to send to this round.

        Returns ``neighbors`` unchanged; override to send to a subset.
        """
        # modify neighbors here
        return neighbors

    def serialized_model(self):
        """
        Serialize the model's ``state_dict`` for sending.

        Returns:
            dict mapping each state-dict key to a JSON string holding the
            tensor's values as (nested) lists. The dtype is not transmitted.
        """
        # NOTE(review): ``Tensor.numpy()`` only works on CPU tensors; a model
        # on another device would need ``.cpu()`` here.
        return {
            key: json.dumps(val.numpy().tolist())
            for key, val in self.model.state_dict().items()
        }

    def deserialized_model(self, m):
        """
        Inverse of ``serialized_model``.

        Args:
            m: dict of key -> JSON string as produced by ``serialized_model``
                (with any extra entries such as ``"degree"`` removed).

        Returns:
            dict of key -> ``torch.Tensor``. The JSON carries no dtype, so
            numpy infers it: floats come back as float64 and integers as
            numpy's default integer type, whatever the sender's dtype was.
        """
        return {
            key: torch.from_numpy(numpy.array(json.loads(value)))
            for key, value in m.items()
        }

    def step(self):
        """
        Run one round of sharing followed by Metropolis-Hastings averaging.

        1. Send the serialized model, plus this node's degree under the
           ``"degree"`` key, to every neighbor returned by ``get_neighbors``.
        2. Call ``communication.receive()`` until every graph neighbor has
           at least one queued model.
        3. Pop one model per neighbor j and set the local model to
               sum_j w_j * x_j + (1 - sum_j w_j) * x_self,
           where w_j = 1 / (max(deg_self, deg_j) + 1).

        A node with no neighbors keeps its model unchanged.
        """
        all_neighbors = self.graph.neighbors(self.uid)
        iter_neighbors = self.get_neighbors(all_neighbors)

        outgoing = self.serialized_model()
        # Receivers need the sender's degree to compute the pairwise weight.
        outgoing["degree"] = len(all_neighbors)
        for neighbor in iter_neighbors:
            self.communication.send(neighbor, outgoing)

        while not self.received_from_all():
            sender, incoming = self.communication.receive()
            logging.info("Received model from {}".format(sender))
            degree = incoming.pop("degree")
            self.peer_deques[sender].append((degree, self.deserialized_model(incoming)))

        logging.info("Starting model averaging after receiving from all neighbors")
        my_degree = len(self.peer_deques)
        total = dict()
        weight_total = 0
        for neighbor, queue in self.peer_deques.items():
            logging.debug("Averaging model from neighbor {}".format(neighbor))
            degree, state = queue.popleft()
            weight = 1 / (max(my_degree, degree) + 1)  # Metro-Hastings
            weight_total += weight
            for key, value in state.items():
                if key in total:
                    total[key] += value * weight
                else:
                    total[key] = value * weight

        # The local model receives the remaining weight (Metro-Hastings
        # self-weight). With no neighbors, total is empty and the model is
        # reloaded unchanged.
        self_weight = 1 - weight_total
        for key, value in self.model.state_dict().items():
            if key in total:
                total[key] += self_weight * value
            else:
                total[key] = self_weight * value
        self.model.load_state_dict(total)