"README.md" did not exist on "c8dd50b48eb3367650bfbd11f9cf8ceb62e91662"
Newer
Older
from decentralizepy.sharing.Sharing import Sharing
"""
This class implements the vanilla version of partial model sharing.
"""
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
rank, machine_id, communication, mapping, graph, model, dataset, log_dir
# NOTE(review): fragment of __init__ — the `def` line and the assignments that
# precede this (self.rank, self.log_dir, self.alpha, ...) are not visible in
# this chunk, and the scrape has stripped all indentation.
# Only the first two local ranks persist their shared-parameter indices,
# matching "only writing for 2 procs per machine" in the docstring above.
if rank == 0 or rank == 1:
self.save_shared = True
if self.save_shared:
# Per-process output directory: <log_dir>/shared_params/<rank>
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
# Create the directory tree up front; idempotent across restarts.
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
"""
Extract the indices and values of the topK gradients.
The gradients must have been accumulated.
Returns
-------
tuple
(a,b). a: The magnitudes of the topK gradients, b: Their indices.
"""
assert len(self.model.accumulated_gradients) > 0
gradient_sum = self.model.accumulated_gradients[0]
for i in range(1, len(self.model.accumulated_gradients)):
for key in self.model.accumulated_gradients[i]:
gradient_sum[key] += self.model.accumulated_gradients[i][key]
logging.info("Returning topk gradients")
tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()]
G_topk = torch.abs(torch.cat(tensors_to_cat, dim=0))
std, mean = torch.std_mean(G_topk, unbiased=False)
self.std = std.item()
self.mean = mean.item()
return torch.topk(
G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=False
)
# NOTE(review): fragment — the enclosing `def serialized_model(self):` line,
# the docstring delimiters around the next line, the initialization of `m`
# (presumably `m = dict()`), and the final return of the serialized dict are
# all missing from this chunk; indentation was stripped by the scrape.
Convert model to a dict. self.alpha specifies the fraction of model to send.
# Sharing metadata (indices) is only worthwhile below the cap; above it,
# fall back to the parent's full-model serialization.
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
# Indices of the top-alpha accumulated-gradient magnitudes.
_, G_topk = self.extract_top_gradients()
if self.save_shared:
# Log which indices were shared this round, plus enough layout info
# (key order and per-key shapes) to reconstruct positions offline.
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = G_topk.tolist()
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
# Flatten the current parameters and pick out the values at the
# selected gradient indices.
tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
T = torch.cat(tensors_to_cat, dim=0)
T_topk = T[G_topk]
# Only the insertion-ordered-dict variant is implemented.
if not self.dict_ordered:
raise NotImplementedError
# NOTE(review): `m` is not defined in the visible fragment — presumably
# initialized as an empty dict on a missing line above.
m["indices"] = G_topk.numpy()
m["params"] = T_topk.numpy()
assert len(m["indices"]) == len(m["params"])
logging.info("Elements sending: {}".format(len(m["indices"])))
logging.info("Converted dictionary to pickle")
# Byte-accounting: payload vs metadata sizes as measured after encryption.
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"]))
Returns
-------
state_dict
state_dict of received
"""
# NOTE(review): fragment — the enclosing `def deserialized_model(self, m):`
# line, the docstring opening, and the construction of `index_tensor`
# (presumably built from m["indices"]) are missing from this chunk;
# indentation was stripped by the scrape.
with torch.no_grad():
# Start from the local model's current parameters; received values
# will be scattered over them at the shared indices.
state_dict = self.model.state_dict()
# Only the insertion-ordered-dict variant is implemented.
if not self.dict_ordered:
raise NotImplementedError
# Flatten all parameter tensors into one 1-D tensor, remembering each
# key's shape and flattened length so it can be re-split afterwards.
shapes = []
lens = []
tensors_to_cat = []
for _, v in state_dict.items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
T = torch.cat(tensors_to_cat, dim=0)
logging.debug("Original tensor: {}".format(T[index_tensor]))
# Overwrite only the received positions with the received values.
T[index_tensor] = torch.tensor(m["params"])
logging.debug("Final tensor: {}".format(T[index_tensor]))
# Re-split the flat tensor back into per-key tensors in dict order.
start_index = 0
for i, key in enumerate(state_dict):
end_index = start_index + lens[i]
state_dict[key] = T[start_index:end_index].reshape(shapes[i])
start_index = end_index
return state_dict