Skip to content
Snippets Groups Projects
Commit e36ca70d authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Add ManualAdapt

parent 52720200
No related branches found
No related tags found
No related merge requests found
import logging
from decentralizepy.sharing.PartialModel import PartialModel
class ManualAdapt(PartialModel):
"""
This class implements the basic growing partial model sharing provided when and what alpha to set.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
change_alpha: list,
change_rounds: list,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph reprensenting neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
change_alpha : list
List of alphas to set. change_alpha[0] must be initial alpha.
change_rounds : list
List of iterations to change alpha. len(change_alpha) = len(change_rounds) + 1.
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
super().__init__(
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
change_alpha[0],
dict_ordered,
save_shared,
metadata_cap,
)
self.change_alpha = change_alpha[1:]
self.change_rounds = change_rounds
def step(self):
"""
Perform a sharing step. Implements D-PSGD with alpha manually given.
"""
if (
len(self.change_rounds)
and (self.communication_round + 1) == self.change_rounds[0]
):
self.alpha = min(self.change_alpha[0], 1.00)
self.change_alpha = self.change_alpha[1:]
self.change_rounds = self.change_rounds[1:]
if self.alpha == 0.0:
logging.info("Not sending/receiving data (alpha=0.0)")
self.communication_round += 1
return
super().step()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment