Commit f005b1e4 authored by Jeffrey Wigger's avatar Jeffrey Wigger
WaveletBound.py

parent d20cf02d
import json
import logging
import os
from pathlib import Path
from time import time
import numpy as np
import pywt
import torch
from decentralizepy.sharing.LowerBoundTopK import LowerBoundTopK
def change_transformer_wavelet(x, wavelet, level):
"""
Transforms the model change into the wavelet frequency domain
Parameters
----------
x : torch.Tensor
Model change in the space domain
wavelet : str
name of the wavelet to be used in gradient compression
level : int
decomposition level of the wavelet transform
Returns
-------
x : torch.Tensor
Representation of the change in the wavelet domain
"""
coeff = pywt.wavedec(x, wavelet, level=level)
data, _ = pywt.coeffs_to_array(coeff)
return torch.from_numpy(data.ravel())
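# Illustrative round-trip sketch (not part of the original commit): the same
# pywt calls used above, showing how coeffs_to_array flattens the per-level
# coefficient list and how the slice metadata recovers it. The function name
# is hypothetical.
def _wavelet_round_trip_example():
    x = torch.randn(64)
    # Forward transform: list of per-level coefficient arrays.
    coeff = pywt.wavedec(x.numpy(), "haar", level=4)
    # Flatten into one array plus slice metadata needed for reconstruction.
    data, coeff_slices = pywt.coeffs_to_array(coeff)
    # Inverse: rebuild the coefficient list, then invert the transform.
    coeff_back = pywt.array_to_coeffs(data, coeff_slices, output_format="wavedec")
    x_back = torch.from_numpy(pywt.waverec(coeff_back, wavelet="haar")).float()
    assert torch.allclose(x, x_back[: x.shape[0]], atol=1e-5)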
class WaveletBound(LowerBoundTopK):
"""
This class implements the wavelet version of model sharing.
It is based on PartialModel.py.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
wavelet="haar",
level=4,
change_based_selection=True,
save_accumulated="",
accumulation=False,
accumulate_averaging_changes=False,
lower_bound=0.1,
metro_hastings=True,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Fraction of the model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share the full model when self.alpha >= metadata_cap
wavelet : str
name of the wavelet to be used in gradient compression
level : int
decomposition level of the wavelet transform
change_based_selection : bool
True if the top-k coefficients should be selected based on the frequency-domain change rather than on the raw coefficient magnitudes
save_accumulated : str
If non-empty, the accumulated weight change in the wavelet domain is written to file. When accumulation is enabled, the accumulated change is stored.
accumulation : bool
True if the indices to share should be selected based on the accumulated frequency change
accumulate_averaging_changes : bool
True if the accumulation should account for the model change due to averaging
lower_bound : float
Every coefficient should be shared in at least this fraction of communication rounds; under-shared coefficients are randomly added to the selection (0.0 disables the bound)
metro_hastings : bool
True if Metropolis-Hastings weights should be used when averaging
"""
self.wavelet = wavelet
self.level = level
super().__init__(
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
lower_bound,
metro_hastings,
alpha=alpha,
dict_ordered=dict_ordered,
save_shared=save_shared,
metadata_cap=metadata_cap,
accumulation=accumulation,
save_accumulated=save_accumulated,
change_transformer=lambda x: change_transformer_wavelet(x, wavelet, level),
accumulate_averaging_changes=accumulate_averaging_changes,
)
self.change_based_selection = change_based_selection
# Do a dummy transform to get the shape and the coefficient slices
coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
self.wt_shape = data.shape
self.coeff_slices = coeff_slices
def apply_wavelet(self):
"""
Applies the wavelet transform to the model parameters and selects the top-k (alpha fraction) coefficients in the frequency domain, based on the change they underwent during the current training step.
Returns
-------
tuple
(a, b), where a are the selected wavelet coefficients and b are their indices
"""
logging.info("Returning wavelet compressed model weights")
data = self.pre_share_model_transformed
if self.change_based_selection:
diff = self.model.model_change
_, index = torch.topk(
diff.abs(),
round(self.alpha * len(diff)),
dim=0,
sorted=False,
)
else:
_, index = torch.topk(
data.abs(),
round(self.alpha * len(data)),
dim=0,
sorted=False,
)
index, _ = torch.sort(index)
return data[index], index
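# Worked example (illustrative, not from the original file): with alpha = 0.5
# and a 4-coefficient change vector, round(0.5 * 4) = 2 coefficients survive.
# >>> diff = torch.tensor([0.1, -2.0, 0.5, 3.0])
# >>> _, idx = torch.topk(diff.abs(), round(0.5 * len(diff)), dim=0, sorted=False)
# >>> torch.sort(idx)[0]   # tensor([1, 3]) -> values -2.0 and 3.0 are kept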
def extract_top_gradients(self):
"""
Extract the indices and values of the top-k gradients.
The gradients must have been accumulated.
Returns
-------
tuple
(a, b), where a are the selected top-k coefficient values and b are their indices
"""
if self.lower_bound == 0.0:
return self.apply_wavelet()
data = self.pre_share_model_transformed
if self.change_based_selection:
diff = self.model.model_change
_, index = torch.topk(
diff.abs(),
round(self.alpha * len(diff)),
dim=0,
sorted=False,
)
else:
_, index = torch.topk(
data.abs(),
round(self.alpha * len(data)),
dim=0,
sorted=False,
)
ind, _ = torch.sort(index)
if self.communication_round > self.start_lower_bounding_at:
# the superclass increments the counter at a point inconvenient for this subclass, so simulate the increment on a copy here
currently_shared = self.model.shared_parameters_counter.clone().detach()
currently_shared[ind] += 1
ind_small = (
currently_shared < self.communication_round * self.lower_bound
).nonzero(as_tuple=True)[0]
ind_small_unique = np.setdiff1d(
ind_small.numpy(), ind.numpy(), assume_unique=True
)
take_max = round(self.lower_bound * self.alpha * data.shape[0])
logging.info(
"lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max
)
if take_max > ind_small_unique.shape[0]:
take_max = ind_small_unique.shape[0]
to_take = torch.rand(ind_small_unique.shape[0])
_, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False)
ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take]
logging.info("lower bounding: %i %i", len(ind), len(ind_bound))
# The corresponding values are not concatenated here; they are absolute values and are not used further downstream.
ind = torch.cat([ind, ind_bound])
index, _ = torch.sort(ind)
return data[index], index
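# Illustrative sketch (assumption, not part of the original commit) of the
# lower-bounding step above: under-shared coefficients outside the top-k set
# are sampled uniformly at random and appended to the selection.
# >>> candidates = np.setdiff1d(ind_small.numpy(), ind.numpy(), assume_unique=True)
# >>> k = min(take_max, candidates.shape[0])
# >>> _, pick = torch.topk(torch.rand(candidates.shape[0]), k, dim=0, sorted=False)
# >>> ind_bound = torch.from_numpy(candidates)[pick]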
def serialized_model(self):
"""
Convert the model to a dict ready for serialization. self.alpha specifies the fraction of the model to send.
Returns
-------
dict
Model converted to a dict for serialization
"""
m = dict()
if self.alpha >= self.metadata_cap: # Share fully
data = self.pre_share_model_transformed
m["params"] = data.numpy()
if self.model.accumulated_changes is not None:
self.model.accumulated_changes = torch.zeros_like(
self.model.accumulated_changes
)
return m
with torch.no_grad():
topk, indices = self.apply_wavelet()
self.model.shared_parameters_counter[indices] += 1
self.model.rewind_accumulation(indices)
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = indices.tolist() # is slow
shared_params["alpha"] = self.alpha
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["params"] = topk.numpy()
m["indices"] = indices.numpy().astype(np.int32)
m["send_partial"] = True
return m
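# Sketch of the partial-share payload produced above (field names taken from
# serialized_model; the values shown are illustrative only):
# m = {
#     "alpha": 0.1,                                 # fraction shared
#     "params": np.array([...], dtype=np.float32),  # top-k wavelet coefficients
#     "indices": np.array([...], dtype=np.int32),   # their flat positions
#     "send_partial": True,
# }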
def deserialized_model_avg(self, m):
"""
Convert a received dict into a dense wavelet-domain tensor for averaging.
Parameters
----------
m : dict
received dict
Returns
-------
tuple
(T, index_tensor): a dense wavelet-domain tensor with the received coefficients scattered into place, and their indices
"""
if "send_partial" not in m:
return super().deserialized_model(m)
with torch.no_grad():
state_dict = self.model.state_dict()
if not self.dict_ordered:
raise NotImplementedError
# could be made more efficient
T = torch.zeros_like(self.init_model)
index_tensor = torch.tensor(m["indices"], dtype=torch.long)
logging.debug("Original tensor: {}".format(T[index_tensor]))
T[index_tensor] = torch.tensor(m["params"])
logging.debug("Final tensor: {}".format(T[index_tensor]))
return T, index_tensor
def deserialized_model(self, m):
"""
Convert a received dict into tensors.
Parameters
----------
m : dict
received dict
Returns
-------
dict
dict with the received parameters and, for partial shares, their indices
"""
ret = dict()
if "send_partial" not in m:
params = m["params"]
params_tensor = torch.tensor(params)
ret["params"] = params_tensor
return ret
with torch.no_grad():
if not self.dict_ordered:
raise NotImplementedError
params_tensor = torch.tensor(m["params"])
indices_tensor = torch.tensor(m["indices"], dtype=torch.long)
ret["indices"] = indices_tensor
ret["params"] = params_tensor
ret["send_partial"] = True
return ret
def _averaging(self):
"""
Averages the received model with the local model
"""
with torch.no_grad():
total = None
weight_total = 0
wt_params = self.pre_share_model_transformed
if not self.metro_hastings:
weight_vector = torch.ones_like(wt_params)
datas = []
for n in self.peer_deques:
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
data = self.deserialized_model(data)
params = data["params"]
if "indices" in data:
indices = data["indices"]
if not self.metro_hastings:
weight_vector[indices] += 1
topkwf = torch.zeros_like(wt_params)
topkwf[indices] = params
topkwf = topkwf.reshape(self.wt_shape)
datas.append(topkwf)
else:
# use local data to complement
topkwf = wt_params.clone().detach()
topkwf[indices] = params
topkwf = topkwf.reshape(self.wt_shape)
else:
topkwf = params.reshape(self.wt_shape)
if not self.metro_hastings:
weight_vector += 1
datas.append(topkwf)
if self.metro_hastings:
weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings
weight_total += weight
if total is None:
total = weight * topkwf
else:
total += weight * topkwf
if not self.metro_hastings:
weight_vector = 1.0 / weight_vector
# could be sped up by exploiting sparsity
total = wt_params * weight_vector
for d in datas:
total += d * weight_vector
else:
# Metro-Hastings
total += (1 - weight_total) * wt_params
avg_wf_params = pywt.array_to_coeffs(
total.numpy(), self.coeff_slices, output_format="wavedec"
)
reverse_total = torch.from_numpy(
pywt.waverec(avg_wf_params, wavelet=self.wavelet)
)
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)
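# Illustrative sketch (assumption, not part of the original commit):
# unflattening a flat parameter vector back into named per-tensor slices,
# mirroring the tail of _averaging above. `keys`, `lens`, and `shapes` are
# hypothetical stand-ins for the model's state_dict metadata.
def _unflatten_example(flat, keys, lens, shapes):
    std_dict = {}
    start_index = 0
    for key, length, shape in zip(keys, lens, shapes):
        end_index = start_index + length
        # Slice out this tensor's flat segment and restore its shape.
        std_dict[key] = flat[start_index:end_index].reshape(shape)
        start_index = end_index
    return std_dict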