Skip to content
Snippets Groups Projects
Commit da846d2a authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Update random alpha

parent 1586bbba
No related branches found
No related tags found
No related merge requests found
......@@ -201,6 +201,8 @@ class PartialModel(Sharing):
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["indices"] = G_topk.numpy().astype(np.int32)
m["params"] = T_topk.numpy()
......
import random
from decentralizepy.sharing.PartialModel import PartialModel
from decentralizepy.utils import identity
class RandomAlpha(PartialModel):
......@@ -19,9 +20,14 @@ class RandomAlpha(PartialModel):
model,
dataset,
log_dir,
alpha_list=[0.1,0.2,0.3,0.4,1.0],
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
accumulation=False,
save_accumulated="",
change_transformer=identity,
accumulate_averaging_changes=False,
):
"""
Constructor
......@@ -65,6 +71,14 @@ class RandomAlpha(PartialModel):
dict_ordered,
save_shared,
metadata_cap,
accumulation,
save_accumulated,
change_transformer,
accumulate_averaging_changes
)
self.alpha_list = eval(alpha_list)
random.seed(
self.mapping.get_uid(self.rank, self.machine_id)
)
def step(self):
......@@ -72,8 +86,5 @@ class RandomAlpha(PartialModel):
Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
"""
random.seed(
self.mapping.get_uid(self.rank, self.machine_id) + self.communication_round
)
self.alpha = random.randint(1, 7) / 10.0
self.alpha = random.choice(self.alpha_list)
super().step()
import random
from decentralizepy.sharing.Wavelet import Wavelet
class RandomAlpha(Wavelet):
"""
This class implements the partial model sharing with a random alpha each iteration.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha_list=[0.1,0.2,0.3,0.4,1.0],
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
wavelet="haar",
level=4,
change_based_selection=True,
save_accumulated="",
accumulation=False,
accumulate_averaging_changes=False,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph reprensenting neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
super().__init__(
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
1.0,
dict_ordered,
save_shared,
metadata_cap,
wavelet,
level,
change_based_selection,
save_accumulated,
accumulation,
accumulate_averaging_changes,
)
self.alpha_list = eval(alpha_list)
random.seed(
self.mapping.get_uid(self.rank, self.machine_id)
)
def step(self):
"""
Perform a sharing step. Implements D-PSGD with alpha randomly chosen.
"""
self.alpha = random.choice(self.alpha_list)
super().step()
......@@ -179,7 +179,7 @@ class Wavelet(PartialModel):
Model converted to json dict
"""
if self.alpha > self.metadata_cap: # Share fully
if self.alpha >= self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
......@@ -218,6 +218,8 @@ class Wavelet(PartialModel):
m["indices"] = indices.numpy().astype(np.int32)
m["send_partial"] = True
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["alpha"])
......@@ -240,7 +242,7 @@ class Wavelet(PartialModel):
state_dict of received
"""
if self.alpha > self.metadata_cap: # Share fully
if "send_partial" not in m:
return super().deserialized_model(m)
with torch.no_grad():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment