import json
import logging
import os
from pathlib import Path
from time import time
import numpy as np
import torch
import torch.fft as fft
from decentralizepy.sharing.PartialModel import PartialModel
def change_transformer_fft(x):
"""
    Transforms the model changes into the frequency domain
Parameters
----------
x : torch.Tensor
Model change in the space domain
Returns
-------
x : torch.Tensor
        Representation of the change in the frequency domain
"""
return fft.rfft(x)
class FFT(PartialModel):
"""
This class implements the fft version of model sharing
It is based on PartialModel.py
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
change_based_selection=True,
        save_accumulated="",
        accumulation=True,
        accumulate_averaging_changes=False,
    ):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
            Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
change_based_selection : bool
            Use the frequency change to select the top-k frequencies
save_accumulated : bool
True if accumulated weight change in the frequency domain should be written to file. In case of accumulation
the accumulated change is stored.
accumulation : bool
            True if the indices to share should be selected based on the accumulated frequency change
accumulate_averaging_changes: bool
            True if the accumulation should account for the model change due to averaging
        """
        super().__init__(
            rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha,
dict_ordered,
save_shared,
metadata_cap,
accumulation,
save_accumulated,
change_transformer_fft,
accumulate_averaging_changes,
)
self.change_based_selection = change_based_selection
def apply_fft(self):
"""
        Does an FFT transformation of the model parameters and selects the top-k (alpha fraction)
        of them in the frequency domain, based on the change undergone during the current training step
Returns
-------
tuple
(a,b). a: selected fft frequencies (complex numbers), b: Their indices.
"""
logging.info("Returning fft compressed model weights")
with torch.no_grad():
flat_fft = self.pre_share_model_transformed
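            # Pick indices either by the magnitude of the model change in the frequency
            # domain (change-based selection) or by the magnitude of the frequencies themselves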
if self.change_based_selection:
diff = self.model.model_change
_, index = torch.topk(
diff.abs(), round(self.alpha * len(diff)), dim=0, sorted=False
)
else:
_, index = torch.topk(
flat_fft.abs(),
round(self.alpha * len(flat_fft)),
dim=0,
sorted=False,
)
return flat_fft[index], index
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
"""
m = dict()
if self.alpha >= self.metadata_cap: # Share fully
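            # Alpha reached the metadata cap: share every frequency coefficient,
            # so no index metadata has to be transmitted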
data = self.pre_share_model_transformed
m["params"] = data.numpy()
self.total_data += len(self.communication.encrypt(m["params"]))
if self.model.accumulated_changes is not None:
self.model.accumulated_changes = torch.zeros_like(
self.model.accumulated_changes
                )
            return m
with torch.no_grad():
topk, indices = self.apply_fft()
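            # Record which frequencies are shared this round and reset their accumulated change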
self.model.shared_parameters_counter[indices] += 1
self.model.rewind_accumulation(indices)
if self.save_shared:
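                # Log the layer shapes and the indices shared this round for offline analysis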
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = indices.tolist() # is slow
shared_params["alpha"] = self.alpha
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["params"] = topk.numpy()
m["indices"] = indices.numpy().astype(np.int32)
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["alpha"])
)
return m
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Parameters
----------
m : dict
json dict received
Returns
-------
        state_dict
            state_dict of the received model
"""
ret = dict()
if "send_partial" not in m:
params = m["params"]
params_tensor = torch.tensor(params)
ret["params"] = params_tensor
with torch.no_grad():
if not self.dict_ordered:
raise NotImplementedError
indices = m["indices"]
alpha = m["alpha"]
params = m["params"]
params_tensor = torch.tensor(params)
indices_tensor = torch.tensor(indices, dtype=torch.long)
ret["indices"] = indices_tensor
ret["params"] = params_tensor
return ret
def _averaging(self):
        """
        Averages the received model with the local model
        """
total = None
weight_total = 0
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
pre_share_model = torch.cat(tensors_to_cat, dim=0)
flat_fft = self.change_transformer(pre_share_model)
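        # Merge each neighbor's frequency coefficients with the local ones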
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
params = data["params"]
if "indices" in data:
indices = data["indices"]
# use local data to complement
topkf = flat_fft.clone().detach()
topkf[indices] = params
else:
topkf = params
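            # Metropolis-Hastings style weights: each neighbor gets 1 / (max(#neighbors, degree) + 1);
            # whatever weight remains is assigned to the local model after the loop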
weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings
weight_total += weight
if total is None:
total = weight * topkf
else:
total += weight * topkf
# Metro-Hastings
total += (1 - weight_total) * flat_fft
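        # Inverse FFT: map the averaged coefficients back to flattened model parameters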
reverse_total = fft.irfft(total)
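        # Split the flattened, averaged parameters back into per-layer tensors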
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)