import logging

from decentralizepy.sharing.PartialModel import PartialModel


class ManualAdapt(PartialModel):
    """
    This class implements basic growing partial model sharing, where when to
    change alpha and what value to set it to are provided manually.

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        change_alpha,
        change_rounds,
        dict_ordered=True,
        save_shared=False,
        metadata_cap=1.0,
    ):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph reprensenting neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
change_alpha : list
List of alphas to set. change_alpha[0] must be initial alpha.
change_rounds : list
List of iterations to change alpha. len(change_alpha) = len(change_rounds) + 1.
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
        assert change_alpha != ""
        assert change_alpha is not None
        assert change_rounds != ""
        assert change_rounds is not None

        # Config files may pass these as strings; convert them to lists.
        if isinstance(change_alpha, str):
            change_alpha = eval(change_alpha)
        if isinstance(change_rounds, str):
            change_rounds = eval(change_rounds)

        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            change_alpha[0],
            dict_ordered,
            save_shared,
            metadata_cap,
        )

        self.change_alpha = change_alpha[1:]
        self.change_rounds = change_rounds
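        # Example schedule: with change_alpha=[0.1, 0.3, 1.0] and
        # change_rounds=[10, 30], alpha starts at 0.1 (passed to PartialModel
        # above); step() then raises it to 0.3 at round 10 and to 1.0 at
        # round 30.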

    def step(self):
        """
        Perform a sharing step. Implements D-PSGD with alpha manually given.

        """
        if (
            len(self.change_rounds)
            and (self.communication_round + 1) == self.change_rounds[0]
        ):
            # Advance to the next alpha in the schedule, capped at 1.0.
            self.alpha = min(self.change_alpha[0], 1.00)
            self.change_alpha = self.change_alpha[1:]
            self.change_rounds = self.change_rounds[1:]

        if self.alpha == 0.0:
            # Nothing to share this round; still advance the round counter.
            logging.info("Not sending/receiving data (alpha=0.0)")
            self.communication_round += 1
            return

        super().step()
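

# Usage sketch: a minimal, hypothetical instantiation, assuming the
# communication, mapping, graph, model, and dataset objects have already been
# constructed the way decentralizepy normally wires them; the names and
# values below are placeholders, not part of this module.
#
#   sharing = ManualAdapt(
#       rank=0,
#       machine_id=0,
#       communication=communication,
#       mapping=mapping,
#       graph=graph,
#       model=model,
#       dataset=dataset,
#       log_dir="./logs",
#       change_alpha="[0.1, 0.3, 1.0]",  # strings are eval'd into lists
#       change_rounds="[10, 30]",        # len(change_alpha) == len(change_rounds) + 1
#   )
#
# Each training iteration then calls sharing.step(), and alpha follows the
# manual schedule above.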