import json
import logging
import os

import torch

from decentralizepy.sharing.Sharing import Sharing


class PartialModel(Sharing):
    """Shares only the alpha fraction of parameters whose accumulated
    gradients have the largest magnitude (top-k sparsification)."""
    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
    ):
        super().__init__(
            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
        )
        self.alpha = alpha
        self.dict_ordered = dict_ordered
        self.communication_round = 0

    def extract_top_gradients(self):
        # Sum the gradients accumulated over the local training steps
        assert len(self.model.accumulated_gradients) > 0
        gradient_sum = self.model.accumulated_gradients[0]
        for i in range(1, len(self.model.accumulated_gradients)):
            for key in self.model.accumulated_gradients[i]:
                gradient_sum[key] += self.model.accumulated_gradients[i][key]

        logging.info("Returning topk gradients")
        tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()]
        G_topk = torch.abs(torch.cat(tensors_to_cat, dim=0))
        # Keep the alpha fraction of entries with the largest magnitude
        return torch.topk(
            G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=False
        )
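
    # Illustrative sketch (not part of the original class): what the top-k
    # selection above returns. With alpha = 0.5 over four gradient entries,
    # torch.topk keeps the two largest magnitudes and reports their positions:
    #
    #   g = torch.tensor([0.1, -0.9, 0.4, 0.7])
    #   values, indices = torch.topk(torch.abs(g), round(0.5 * 4), dim=0, sorted=False)
    #   # values  -> tensor([0.9000, 0.7000])  (in some order, since sorted=False)
    #   # indices -> tensor([1, 3])            (in the matching order)
    #
    # serialized_model() below uses only the indices, re-reading the actual
    # parameter values from the model.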

    def serialized_model(self):
        with torch.no_grad():
            _, G_topk = self.extract_top_gradients()

            if self.communication_round:
                with open(
                    os.path.join(
                        self.log_dir, "{}_shared_params.json".format(self.rank)
                    ),
                    "r",
                ) as inf:
                    shared_params = json.load(inf)
            else:
                shared_params = dict()
shared_params["order"] = self.model.state_dict().keys()
shapes = dict()
for k, v in self.model.state_dict.items():
shapes[k] = v.shape.tolist()
shared_params["shapes"] = shapes
shared_params[self.communication_round] = G_topk.tolist()
            with open(
                os.path.join(self.log_dir, "{}_shared_params.json".format(self.rank)),
                "w",
            ) as of:
                json.dump(shared_params, of)
            # Advance the round counter so the next call appends to this log
            self.communication_round += 1
            tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
            T = torch.cat(tensors_to_cat, dim=0)
            T_topk = T[G_topk]

            if not self.dict_ordered:
                raise NotImplementedError
m["indices"] = G_topk.numpy().tolist()
m["params"] = T_topk.numpy().tolist()
assert len(m["indices"]) == len(m["params"])
logging.info("Elements sending: {}".format(len(m["indices"])))
logging.info("Converted dictionary to json")
return m
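
    # Sketch of the resulting message, assuming a model with five parameters
    # of which two are shared (alpha = 0.4); field names match the code above,
    # concrete values are made up for illustration:
    #
    #   m = {
    #       "indices": "[1, 3]",       # JSON-encoded list of flat indices
    #       "params": "[0.25, -1.5]",  # JSON-encoded parameter values
    #   }
    #
    # deserialized_model() below json.loads both fields and scatters the
    # values back into a flattened copy of the receiver's state_dict.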

    def deserialized_model(self, m):
        with torch.no_grad():
            state_dict = self.model.state_dict()

            if not self.dict_ordered:
                raise NotImplementedError

            # Flatten every parameter tensor, remembering shapes and lengths
            # so the update can be scattered back afterwards
            shapes = []
            lens = []
            tensors_to_cat = []
            for _, v in state_dict.items():
                shapes.append(v.shape)
                t = v.flatten()
                lens.append(t.shape[0])
                tensors_to_cat.append(t)

            T = torch.cat(tensors_to_cat, dim=0)
            index_tensor = torch.tensor(json.loads(m["indices"]))
            logging.debug("Original tensor: {}".format(T[index_tensor]))
            # Overwrite only the received (top-k) positions
            T[index_tensor] = torch.tensor(json.loads(m["params"]))
            logging.debug("Final tensor: {}".format(T[index_tensor]))

            # Rebuild the state_dict from the flat tensor
            start_index = 0
            for i, key in enumerate(state_dict):
                end_index = start_index + lens[i]
                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
                start_index = end_index

            return state_dict
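
    # Hedged usage sketch, assuming an already-constructed PartialModel
    # instance `sharing` (constructor arguments omitted): a sender serializes
    # its top-alpha parameters, and a receiver turns the message back into a
    # full state_dict. How the surrounding framework consumes that state_dict
    # (e.g. averaging it with the local model) is outside this file.
    #
    #   m = sharing.serialized_model()              # on the sender
    #   new_state = sharing.deserialized_model(m)   # on the receiver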