import json import logging import os from pathlib import Path import numpy as np import pywt import torch from decentralizepy.sharing.SharingDynamicGraph import SharingDynamicGraph from decentralizepy.utils import conditional_value, identity def change_transformer_wavelet(x, wavelet, level): """ Transforms the model changes into wavelet frequency domain Parameters ---------- x : torch.Tensor Model change in the space domain wavelet : str name of the wavelet to be used in gradient compression level: int name of the wavelet to be used in gradient compression Returns ------- x : torch.Tensor Representation of the change int the wavelet domain """ coeff = pywt.wavedec(x, wavelet, level=level) data, coeff_slices = pywt.coeffs_to_array(coeff) return torch.from_numpy(data.ravel()) class JwinsDynamicGraph(SharingDynamicGraph): """ This class implements the vanilla version of partial model sharing. """ def __init__( self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir, alpha=1.0, dict_ordered=True, save_shared=False, metadata_cap=1.0, wavelet="haar", level=4, change_based_selection=True, save_accumulated="", accumulation=False, accumulate_averaging_changes=False, rw_chance=1, rw_length=6, neighbor_bound=(3, 5), change_topo_interval=10, avg = False, lower_bound=0.1, metro_hastings=True, ): """ Constructor Parameters ---------- rank : int Local rank machine_id : int Global machine id communication : decentralizepy.communication.Communication Communication module used to send and receive messages mapping : decentralizepy.mappings.Mapping Mapping (rank, machine_id) -> uid graph : decentralizepy.graphs.Graph Graph reprensenting neighbors model : decentralizepy.models.Model Model to train dataset : decentralizepy.datasets.Dataset Dataset for sharing data. Not implemented yet! TODO log_dir : str Location to write shared_params (only writing for 2 procs per machine) alpha : float Percentage of model to share dict_ordered : bool Specifies if the python dict maintains the order of insertion save_shared : bool Specifies if the indices of shared parameters should be logged metadata_cap : float Share full model when self.alpha > metadata_cap wavelet: str name of the wavelet to be used in gradient compression level: int name of the wavelet to be used in gradient compression change_based_selection : bool use frequency change to select topk frequencies save_accumulated : bool True if accumulated weight change in the wavelet domain should be written to file. In case of accumulation the accumulated change is stored. accumulation : bool True if the the indices to share should be selected based on accumulated frequency change save_accumulated : bool True if accumulated weight change should be written to file. In case of accumulation the accumulated change is stored. If a change_transformer is used then the transformed change is stored. change_transformer : (x: Tensor) -> Tensor A function that transforms the model change into other domains. Default: identity function accumulate_averaging_changes: bool True if the accumulation should account the model change due to averaging """ self.wavelet = wavelet self.level = level super().__init__( rank, machine_id, communication, mapping, graph, model, dataset, log_dir, rw_chance, rw_length, neighbor_bound, change_topo_interval, avg ) self.alpha = alpha self.dict_ordered = dict_ordered self.save_shared = save_shared self.metadata_cap = metadata_cap self.accumulation = accumulation self.save_accumulated = conditional_value(save_accumulated, "", False) self.change_transformer = lambda x: change_transformer_wavelet(x, wavelet, level) self.accumulate_averaging_changes = accumulate_averaging_changes # getting the initial model self.shapes = [] self.lens = [] with torch.no_grad(): tensors_to_cat = [] for _, v in self.model.state_dict().items(): self.shapes.append(v.shape) t = v.flatten() self.lens.append(t.shape[0]) tensors_to_cat.append(t) self.init_model = torch.cat(tensors_to_cat, dim=0) if self.accumulation: self.model.accumulated_changes = torch.zeros_like( self.change_transformer(self.init_model) ) self.prev = self.init_model self.number_of_params = self.init_model.shape[0] if self.save_accumulated: self.model_change_path = os.path.join( self.log_dir, "model_change/{}".format(self.rank) ) Path(self.model_change_path).mkdir(parents=True, exist_ok=True) self.model_val_path = os.path.join( self.log_dir, "model_val/{}".format(self.rank) ) Path(self.model_val_path).mkdir(parents=True, exist_ok=True) # Only save for 2 procs: Save space if self.save_shared and not (rank == 0 or rank == 1): self.save_shared = False if self.save_shared: self.folder_path = os.path.join( self.log_dir, "shared_params/{}".format(self.rank) ) Path(self.folder_path).mkdir(parents=True, exist_ok=True) self.model.shared_parameters_counter = torch.zeros( self.change_transformer(self.init_model).shape[0], dtype=torch.int32 ) self.change_based_selection = change_based_selection # Do a dummy transform to get the shape and coefficents slices coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level) data, coeff_slices = pywt.coeffs_to_array(coeff) self.wt_shape = data.shape self.coeff_slices = coeff_slices self.lower_bound = lower_bound self.metro_hastings = metro_hastings if self.lower_bound > 0: self.start_lower_bounding_at = 1 / self.lower_bound def apply_wavelet(self): """ Does wavelet transformation of the model parameters and selects topK (alpha) of them in the frequency domain based on the undergone change during the current training step Returns ------- tuple (a,b). a: selected wavelet coefficients, b: Their indices. """ data = self.pre_share_model_transformed if self.change_based_selection: diff = self.model.model_change _, index = torch.topk( diff.abs(), round(self.alpha * len(diff)), dim=0, sorted=False, ) else: _, index = torch.topk( data.abs(), round(self.alpha * len(data)), dim=0, sorted=False, ) ind, _ = torch.sort(index) if self.lower_bound == 0.0: return data[ind].clone().detach(), ind if self.communication_round > self.start_lower_bounding_at: # because the superclass increases it where it is inconvenient for this subclass currently_shared = self.model.shared_parameters_counter.clone().detach() currently_shared[ind] += 1 ind_small = ( currently_shared < self.communication_round * self.lower_bound ).nonzero(as_tuple=True)[0] ind_small_unique = np.setdiff1d( ind_small.numpy(), ind.numpy(), assume_unique=True ) take_max = round(self.lower_bound * self.alpha * data.shape[0]) logging.info( "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max ) if take_max > ind_small_unique.shape[0]: take_max = ind_small_unique.shape[0] to_take = torch.rand(ind_small_unique.shape[0]) _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False) ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take] logging.info("lower bounding: %i %i", len(ind), len(ind_bound)) # val = torch.concat(val, G_topk[ind_bound]) # not really needed, as thes are abs values and not further used ind = torch.cat([ind, ind_bound]) index, _ = torch.sort(ind) return data[index].clone().detach(), index def serialized_model(self): """ Convert model to json dict. self.alpha specifies the fraction of model to send. Returns ------- dict Model converted to json dict """ m = dict() if self.alpha >= self.metadata_cap: # Share fully data = self.pre_share_model_transformed m["params"] = data.numpy() if self.model.accumulated_changes is not None: self.model.accumulated_changes = torch.zeros_like( self.model.accumulated_changes ) return m with torch.no_grad(): topk, self.G_topk = self.apply_wavelet() self.model.shared_parameters_counter[self.G_topk] += 1 self.model.rewind_accumulation(self.G_topk) if self.save_shared: shared_params = dict() shared_params["order"] = list(self.model.state_dict().keys()) shapes = dict() for k, v in self.model.state_dict().items(): shapes[k] = list(v.shape) shared_params["shapes"] = shapes shared_params[self.communication_round] = self.G_topk.tolist() with open( os.path.join( self.folder_path, "{}_shared_params.json".format(self.communication_round + 1), ), "w", ) as of: json.dump(shared_params, of) if not self.dict_ordered: raise NotImplementedError m["alpha"] = self.alpha m["indices"] = self.G_topk.numpy().astype(np.int32) m["params"] = topk.numpy() m["send_partial"] = True assert len(m["indices"]) == len(m["params"]) logging.info("Elements sending: {}".format(len(m["indices"]))) logging.info("Generated dictionary to send") logging.info("Converted dictionary to pickle") return m def deserialized_model(self, m): """ Convert received dict to state_dict. Parameters ---------- m : dict received dict Returns ------- state_dict state_dict of received """ ret = dict() if "send_partial" not in m: params = m["params"] params_tensor = torch.tensor(params) ret["params"] = params_tensor return ret with torch.no_grad(): if not self.dict_ordered: raise NotImplementedError params_tensor = torch.tensor(m["params"]) indices_tensor = torch.tensor(m["indices"], dtype=torch.long) ret = dict() ret["indices"] = indices_tensor ret["params"] = params_tensor ret["send_partial"] = True return ret def _averaging(self): """ Averages the received model with the local model """ with torch.no_grad(): total = None weight_total = 0 wt_params = self.pre_share_model_transformed if not self.metro_hastings: weight_vector = torch.ones_like(wt_params) datas = [] batch = self._preprocessing_received_models() for n, vals in batch.items(): if len(vals) > 1: # this should never happen logging.info("GOT double message in dynamic graph!") else: degree, iteration, data = vals[0] #degree, iteration, data = self.peer_deques[n].popleft() logging.debug( "Averaging model from neighbor {} of iteration {}".format( n, iteration ) ) #data = self.deserialized_model(data) params = data["params"] if "indices" in data: indices = data["indices"] if not self.metro_hastings: weight_vector[indices] += 1 topkwf = torch.zeros_like(wt_params) topkwf[indices] = params topkwf = topkwf.reshape(self.wt_shape) datas.append(topkwf) else: # use local data to complement topkwf = wt_params.clone().detach() topkwf[indices] = params topkwf = topkwf.reshape(self.wt_shape) else: topkwf = params.reshape(self.wt_shape) if not self.metro_hastings: weight_vector += 1 datas.append(topkwf) if self.metro_hastings: weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings weight_total += weight if total is None: total = weight * topkwf else: total += weight * topkwf if not self.metro_hastings: weight_vector = 1.0 / weight_vector # speed up by exploiting sparsity total = wt_params * weight_vector for d in datas: total += d * weight_vector else: # Metro-Hastings total += (1 - weight_total) * wt_params avg_wf_params = pywt.array_to_coeffs( total.numpy(), self.coeff_slices, output_format="wavedec" ) reverse_total = torch.from_numpy( pywt.waverec(avg_wf_params, wavelet=self.wavelet) ) start_index = 0 std_dict = {} for i, key in enumerate(self.model.state_dict()): end_index = start_index + self.lens[i] std_dict[key] = reverse_total[start_index:end_index].reshape( self.shapes[i] ) start_index = end_index self.model.load_state_dict(std_dict) def _pre_step(self): """ Called at the beginning of step. """ logging.info("PartialModel _pre_step") with torch.no_grad(): tensors_to_cat = [ v.data.flatten() for _, v in self.model.state_dict().items() ] self.pre_share_model = torch.cat(tensors_to_cat, dim=0) # Would only need one of the transforms self.pre_share_model_transformed = self.change_transformer( self.pre_share_model ) change = self.change_transformer(self.pre_share_model - self.init_model) if self.accumulation: if not self.accumulate_averaging_changes: # Need to accumulate in _pre_step as the accumulation gets rewind during the step self.model.accumulated_changes += change change = self.model.accumulated_changes.clone().detach() else: # For the legacy implementation, we will only rewind currently accumulated values # and add the model change due to averaging in the end change += self.model.accumulated_changes # stores change of the model due to training, change due to averaging is not accounted self.model.model_change = change def _post_step(self): """ Called at the end of step. """ logging.info("PartialModel _post_step") with torch.no_grad(): tensors_to_cat = [ v.data.flatten() for _, v in self.model.state_dict().items() ] post_share_model = torch.cat(tensors_to_cat, dim=0) self.init_model = post_share_model if self.accumulation: if self.accumulate_averaging_changes: self.model.accumulated_changes += self.change_transformer( self.init_model - self.prev ) self.prev = self.init_model self.model.model_change = None if self.save_accumulated: self.save_change() def save_vector(self, v, s): """ Saves the given vector to the file. Parameters ---------- v : torch.tensor The torch tensor to write to file s : str Path to folder to write to """ output_dict = dict() output_dict["order"] = list(self.model.state_dict().keys()) shapes = dict() for k, v1 in self.model.state_dict().items(): shapes[k] = list(v1.shape) output_dict["shapes"] = shapes output_dict["tensor"] = v.tolist() with open( os.path.join( s, "{}.json".format(self.communication_round + 1), ), "w", ) as of: json.dump(output_dict, of) def save_change(self): """ Saves the change and the gradient values for every iteration """ self.save_vector(self.model.model_change, self.model_change_path)