From 4b596b06889942eaffc1bc8053246e160c382f60 Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Sun, 12 Jun 2022 23:36:37 +0200 Subject: [PATCH] reddit dynamic --- eval/run_xtimes_reddit_rws.sh | 2 +- .../config_reddit_sharing_dynamicGraph.ini | 2 +- .../communication/TCPRandomWalk.py | 37 +- .../sharing/JwinsDynamicGraph.py | 511 ++++++++++++++++++ .../sharing/SharingDynamicGraph.py | 4 +- 5 files changed, 539 insertions(+), 17 deletions(-) create mode 100644 src/decentralizepy/sharing/JwinsDynamicGraph.py diff --git a/eval/run_xtimes_reddit_rws.sh b/eval/run_xtimes_reddit_rws.sh index 3726501..704055f 100755 --- a/eval/run_xtimes_reddit_rws.sh +++ b/eval/run_xtimes_reddit_rws.sh @@ -54,7 +54,7 @@ export PYTHONFAULTHANDLER=1 # Base configs for which the gird search is done # tests=("step_configs/config_reddit_sharing_topKdynamicGraph.ini") # tests=("step_configs/config_reddit_sharing_topKsharingasyncrw.ini" "step_configs/config_reddit_sharing_topKdpsgdrwasync.ini" "step_configs/config_reddit_sharing_topKdpsgdrw.ini") -tests=("step_configs/config_reddit_sharing_dpsgdrwasync0.ini") +tests=("step_configs/config_reddit_sharing_dynamicGraph.ini")# ("step_configs/config_reddit_sharing_dpsgdrwasync0.ini") # tests=("step_configs/config_reddit_sharing_dpsgdrw.ini" "step_configs/config_reddit_sharing_dpsgdrwasync.ini" "step_configs/config_reddit_sharing_sharingasyncrw.ini" "step_configs/config_reddit_sharing_sharingrw.ini") # Learning rates lr="1" diff --git a/eval/step_configs/config_reddit_sharing_dynamicGraph.ini b/eval/step_configs/config_reddit_sharing_dynamicGraph.ini index 44849cc..3caeabb 100644 --- a/eval/step_configs/config_reddit_sharing_dynamicGraph.ini +++ b/eval/step_configs/config_reddit_sharing_dynamicGraph.ini @@ -32,4 +32,4 @@ sampler = equi [SHARING] sharing_package = decentralizepy.sharing.SharingDynamicGraph sharing_class = SharingDynamicGraph -avg = True \ No newline at end of file +avg = False \ No newline at end of file diff --git a/src/decentralizepy/communication/TCPRandomWalk.py b/src/decentralizepy/communication/TCPRandomWalk.py index 6dee4db..10dcdfb 100644 --- a/src/decentralizepy/communication/TCPRandomWalk.py +++ b/src/decentralizepy/communication/TCPRandomWalk.py @@ -129,17 +129,23 @@ class TCPRandomWalkBase(Communication): logging.debug("in encrypt") if self.compress: logging.debug("in encrypt: compress") - if "indices" in data: - data["indices"] = self.compressor.compress(data["indices"]) + if type(data) == dict: + if "indices" in data: + data["indices"] = self.compressor.compress(data["indices"]) - assert "params" in data - data["params"] = self.compressor.compress_float(data["params"]) - data_len = len(pickle.dumps(data["params"])) - output = pickle.dumps(data) + if "params" in data: + data["params"] = self.compressor.compress_float(data["params"]) + data_len = len(pickle.dumps(data["params"])) + else: + data_len = 0 + output = pickle.dumps(data) + + # the compressed meta data gets only a few bytes smaller after pickling + self.total_meta.value += (len(output) - data_len) + self.total_data.value += data_len + else: + output = pickle.dumps(data) - # the compressed meta data gets only a few bytes smaller after pickling - self.total_meta.value += (len(output) - data_len) - self.total_data.value += data_len else: logging.debug("in encrypt: else") output = pickle.dumps(data) @@ -175,10 +181,15 @@ class TCPRandomWalkBase(Communication): if self.compress: logging.debug("in decrypt:comp") data = pickle.loads(data) - if "indices" in data: - 
data["indices"] = self.compressor.decompress(data["indices"]) - if "params" in data: - data["params"] = self.compressor.decompress_float(data["params"]) + if type(data) == dict: + if "indices" in data: + data["indices"] = self.compressor.decompress(data["indices"]) + if "params" in data: + try: + data["params"] = self.compressor.decompress_float(data["params"]) + except: + print(f"faled for {data}") + else: logging.debug("in decrypt:else") data = pickle.loads(data) diff --git a/src/decentralizepy/sharing/JwinsDynamicGraph.py b/src/decentralizepy/sharing/JwinsDynamicGraph.py new file mode 100644 index 0000000..c1d6292 --- /dev/null +++ b/src/decentralizepy/sharing/JwinsDynamicGraph.py @@ -0,0 +1,511 @@ +import json +import logging +import os +from pathlib import Path + +import numpy as np +import pywt +import torch +from decentralizepy.sharing.SharingDynamicGraph import SharingDynamicGraph +from decentralizepy.utils import conditional_value, identity + +def change_transformer_wavelet(x, wavelet, level): + """ + Transforms the model changes into wavelet frequency domain + + Parameters + ---------- + x : torch.Tensor + Model change in the space domain + wavelet : str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + + Returns + ------- + x : torch.Tensor + Representation of the change int the wavelet domain + """ + coeff = pywt.wavedec(x, wavelet, level=level) + data, coeff_slices = pywt.coeffs_to_array(coeff) + return torch.from_numpy(data.ravel()) + +class JwinsDynamicGraph(SharingDynamicGraph): + """ + This class implements the vanilla version of partial model sharing. + + """ + + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + alpha=1.0, + dict_ordered=True, + save_shared=False, + metadata_cap=1.0, + wavelet="haar", + level=4, + change_based_selection=True, + save_accumulated="", + accumulation=False, + accumulate_averaging_changes=False, + rw_chance=0.1, + rw_length=4, + neighbor_bound=(3, 5), + change_topo_interval=10, + avg = False, + lower_bound=0.1, + metro_hastings=True, + ): + """ + Constructor + + Parameters + ---------- + rank : int + Local rank + machine_id : int + Global machine id + communication : decentralizepy.communication.Communication + Communication module used to send and receive messages + mapping : decentralizepy.mappings.Mapping + Mapping (rank, machine_id) -> uid + graph : decentralizepy.graphs.Graph + Graph reprensenting neighbors + model : decentralizepy.models.Model + Model to train + dataset : decentralizepy.datasets.Dataset + Dataset for sharing data. Not implemented yet! TODO + log_dir : str + Location to write shared_params (only writing for 2 procs per machine) + alpha : float + Percentage of model to share + dict_ordered : bool + Specifies if the python dict maintains the order of insertion + save_shared : bool + Specifies if the indices of shared parameters should be logged + metadata_cap : float + Share full model when self.alpha > metadata_cap + wavelet: str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + change_based_selection : bool + use frequency change to select topk frequencies + save_accumulated : bool + True if accumulated weight change in the wavelet domain should be written to file. In case of accumulation + the accumulated change is stored. 
+ accumulation : bool + True if the the indices to share should be selected based on accumulated frequency change + save_accumulated : bool + True if accumulated weight change should be written to file. In case of accumulation the accumulated change + is stored. If a change_transformer is used then the transformed change is stored. + change_transformer : (x: Tensor) -> Tensor + A function that transforms the model change into other domains. Default: identity function + accumulate_averaging_changes: bool + True if the accumulation should account the model change due to averaging + """ + self.wavelet = wavelet + self.level = level + + super().__init__( + rank, machine_id, communication, mapping, graph, model, dataset, log_dir, rw_chance, rw_length, neighbor_bound, change_topo_interval, avg + ) + self.alpha = alpha + self.dict_ordered = dict_ordered + self.save_shared = save_shared + self.metadata_cap = metadata_cap + self.accumulation = accumulation + self.save_accumulated = conditional_value(save_accumulated, "", False) + self.change_transformer = lambda x: change_transformer_wavelet(x, wavelet, level) + self.accumulate_averaging_changes = accumulate_averaging_changes + + # getting the initial model + self.shapes = [] + self.lens = [] + with torch.no_grad(): + tensors_to_cat = [] + for _, v in self.model.state_dict().items(): + self.shapes.append(v.shape) + t = v.flatten() + self.lens.append(t.shape[0]) + tensors_to_cat.append(t) + self.init_model = torch.cat(tensors_to_cat, dim=0) + if self.accumulation: + self.model.accumulated_changes = torch.zeros_like( + self.change_transformer(self.init_model) + ) + self.prev = self.init_model + self.number_of_params = self.init_model.shape[0] + if self.save_accumulated: + self.model_change_path = os.path.join( + self.log_dir, "model_change/{}".format(self.rank) + ) + Path(self.model_change_path).mkdir(parents=True, exist_ok=True) + + self.model_val_path = os.path.join( + self.log_dir, "model_val/{}".format(self.rank) + ) + Path(self.model_val_path).mkdir(parents=True, exist_ok=True) + + # Only save for 2 procs: Save space + if self.save_shared and not (rank == 0 or rank == 1): + self.save_shared = False + + if self.save_shared: + self.folder_path = os.path.join( + self.log_dir, "shared_params/{}".format(self.rank) + ) + Path(self.folder_path).mkdir(parents=True, exist_ok=True) + + self.model.shared_parameters_counter = torch.zeros( + self.change_transformer(self.init_model).shape[0], dtype=torch.int32 + ) + + self.change_based_selection = change_based_selection + + # Do a dummy transform to get the shape and coefficents slices + coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level) + data, coeff_slices = pywt.coeffs_to_array(coeff) + self.wt_shape = data.shape + self.coeff_slices = coeff_slices + + self.lower_bound = lower_bound + self.metro_hastings = metro_hastings + if self.lower_bound > 0: + self.start_lower_bounding_at = 1 / self.lower_bound + + def apply_wavelet(self): + """ + Does wavelet transformation of the model parameters and selects topK (alpha) of them in the frequency domain + based on the undergone change during the current training step + + Returns + ------- + tuple + (a,b). a: selected wavelet coefficients, b: Their indices. 
+ + """ + + data = self.pre_share_model_transformed + if self.change_based_selection: + diff = self.model.model_change + _, index = torch.topk( + diff.abs(), + round(self.alpha * len(diff)), + dim=0, + sorted=False, + ) + else: + _, index = torch.topk( + data.abs(), + round(self.alpha * len(data)), + dim=0, + sorted=False, + ) + ind, _ = torch.sort(index) + + if self.lower_bound == 0.0: + return data[ind].clone().detach(), ind + + if self.communication_round > self.start_lower_bounding_at: + # because the superclass increases it where it is inconvenient for this subclass + currently_shared = self.model.shared_parameters_counter.clone().detach() + currently_shared[ind] += 1 + ind_small = ( + currently_shared < self.communication_round * self.lower_bound + ).nonzero(as_tuple=True)[0] + ind_small_unique = np.setdiff1d( + ind_small.numpy(), ind.numpy(), assume_unique=True + ) + take_max = round(self.lower_bound * self.alpha * data.shape[0]) + logging.info( + "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max + ) + if take_max > ind_small_unique.shape[0]: + take_max = ind_small_unique.shape[0] + to_take = torch.rand(ind_small_unique.shape[0]) + _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False) + ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take] + logging.info("lower bounding: %i %i", len(ind), len(ind_bound)) + # val = torch.concat(val, G_topk[ind_bound]) # not really needed, as thes are abs values and not further used + ind = torch.cat([ind, ind_bound]) + + index, _ = torch.sort(ind) + return data[index].clone().detach(), index + + def serialized_model(self): + """ + Convert model to json dict. self.alpha specifies the fraction of model to send. + + Returns + ------- + dict + Model converted to json dict + + """ + m = dict() + if self.alpha >= self.metadata_cap: # Share fully + data = self.pre_share_model_transformed + m["params"] = data.numpy() + if self.model.accumulated_changes is not None: + self.model.accumulated_changes = torch.zeros_like( + self.model.accumulated_changes + ) + return m + + with torch.no_grad(): + topk, self.G_topk = self.apply_wavelet() + self.model.shared_parameters_counter[self.G_topk] += 1 + self.model.rewind_accumulation(self.G_topk) + if self.save_shared: + shared_params = dict() + shared_params["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v in self.model.state_dict().items(): + shapes[k] = list(v.shape) + shared_params["shapes"] = shapes + + shared_params[self.communication_round] = self.G_topk.tolist() + + with open( + os.path.join( + self.folder_path, + "{}_shared_params.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(shared_params, of) + + if not self.dict_ordered: + raise NotImplementedError + + m["alpha"] = self.alpha + + m["indices"] = self.G_topk.numpy().astype(np.int32) + + m["params"] = topk.numpy() + + m["send_partial"] = True + + assert len(m["indices"]) == len(m["params"]) + logging.info("Elements sending: {}".format(len(m["indices"]))) + + logging.info("Generated dictionary to send") + + logging.info("Converted dictionary to pickle") + + return m + + + def deserialized_model(self, m): + """ + Convert received dict to state_dict. 
+ + Parameters + ---------- + m : dict + received dict + + Returns + ------- + state_dict + state_dict of received + + """ + ret = dict() + if "send_partial" not in m: + params = m["params"] + params_tensor = torch.tensor(params) + ret["params"] = params_tensor + return ret + + with torch.no_grad(): + if not self.dict_ordered: + raise NotImplementedError + + params_tensor = torch.tensor(m["params"]) + indices_tensor = torch.tensor(m["indices"], dtype=torch.long) + ret = dict() + ret["indices"] = indices_tensor + ret["params"] = params_tensor + ret["send_partial"] = True + return ret + + def _averaging(self): + """ + Averages the received model with the local model + + """ + with torch.no_grad(): + total = None + weight_total = 0 + wt_params = self.pre_share_model_transformed + if not self.metro_hastings: + weight_vector = torch.ones_like(wt_params) + datas = [] + batch = self._preprocessing_received_models() + for n, vals in batch.items(): + if len(vals) > 1: + # this should never happen + logging.info("GOT double message in dynamic graph!") + else: + degree, iteration, data = vals[0] + #degree, iteration, data = self.peer_deques[n].popleft() + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + #data = self.deserialized_model(data) + params = data["params"] + if "indices" in data: + indices = data["indices"] + if not self.metro_hastings: + weight_vector[indices] += 1 + topkwf = torch.zeros_like(wt_params) + topkwf[indices] = params + topkwf = topkwf.reshape(self.wt_shape) + datas.append(topkwf) + else: + # use local data to complement + topkwf = wt_params.clone().detach() + topkwf[indices] = params + topkwf = topkwf.reshape(self.wt_shape) + + else: + topkwf = params.reshape(self.wt_shape) + if not self.metro_hastings: + weight_vector += 1 + datas.append(topkwf) + + if self.metro_hastings: + weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight_total += weight + if total is None: + total = weight * topkwf + else: + total += weight * topkwf + if not self.metro_hastings: + weight_vector = 1.0 / weight_vector + # speed up by exploiting sparsity + total = wt_params * weight_vector + for d in datas: + total += d * weight_vector + else: + # Metro-Hastings + total += (1 - weight_total) * wt_params + + avg_wf_params = pywt.array_to_coeffs( + total.numpy(), self.coeff_slices, output_format="wavedec" + ) + reverse_total = torch.from_numpy( + pywt.waverec(avg_wf_params, wavelet=self.wavelet) + ) + + start_index = 0 + std_dict = {} + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + std_dict[key] = reverse_total[start_index:end_index].reshape( + self.shapes[i] + ) + start_index = end_index + + self.model.load_state_dict(std_dict) + + + def _pre_step(self): + """ + Called at the beginning of step. 
+ + """ + logging.info("PartialModel _pre_step") + with torch.no_grad(): + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + self.pre_share_model = torch.cat(tensors_to_cat, dim=0) + # Would only need one of the transforms + self.pre_share_model_transformed = self.change_transformer( + self.pre_share_model + ) + change = self.change_transformer(self.pre_share_model - self.init_model) + if self.accumulation: + if not self.accumulate_averaging_changes: + # Need to accumulate in _pre_step as the accumulation gets rewind during the step + self.model.accumulated_changes += change + change = self.model.accumulated_changes.clone().detach() + else: + # For the legacy implementation, we will only rewind currently accumulated values + # and add the model change due to averaging in the end + change += self.model.accumulated_changes + # stores change of the model due to training, change due to averaging is not accounted + self.model.model_change = change + + def _post_step(self): + """ + Called at the end of step. + + """ + logging.info("PartialModel _post_step") + with torch.no_grad(): + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + post_share_model = torch.cat(tensors_to_cat, dim=0) + self.init_model = post_share_model + if self.accumulation: + if self.accumulate_averaging_changes: + self.model.accumulated_changes += self.change_transformer( + self.init_model - self.prev + ) + self.prev = self.init_model + self.model.model_change = None + if self.save_accumulated: + self.save_change() + + def save_vector(self, v, s): + """ + Saves the given vector to the file. + + Parameters + ---------- + v : torch.tensor + The torch tensor to write to file + s : str + Path to folder to write to + + """ + output_dict = dict() + output_dict["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v1 in self.model.state_dict().items(): + shapes[k] = list(v1.shape) + output_dict["shapes"] = shapes + + output_dict["tensor"] = v.tolist() + + with open( + os.path.join( + s, + "{}.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(output_dict, of) + + def save_change(self): + """ + Saves the change and the gradient values for every iteration + + """ + self.save_vector(self.model.model_change, self.model_change_path) diff --git a/src/decentralizepy/sharing/SharingDynamicGraph.py b/src/decentralizepy/sharing/SharingDynamicGraph.py index 2564725..0c744dd 100644 --- a/src/decentralizepy/sharing/SharingDynamicGraph.py +++ b/src/decentralizepy/sharing/SharingDynamicGraph.py @@ -31,7 +31,7 @@ class SharingDynamicGraph(DPSGDRW): rw_length=6, neighbor_bound=(3, 5), change_topo_interval=10, - avg = True + avg = False ): """ Constructor @@ -70,7 +70,7 @@ class SharingDynamicGraph(DPSGDRW): rw_chance, rw_length, ) - self.avg = True + self.avg = avg self.neighbor_bound_lower = neighbor_bound[0] self.neighbor_bound_higher = neighbor_bound[1] self.change_topo_interval = change_topo_interval -- GitLab
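Note on the TCPRandomWalk change: encrypt/decrypt now tolerate non-dict payloads. Compression is applied only when the outgoing message is a dict carrying "indices"/"params" (a serialized model); any other payload, presumably the control messages exchanged by the random-walk / dynamic-topology protocol, is pickled unchanged, and decrypt mirrors the same guard. The following minimal Python sketch (not part of the patch) illustrates that guard; the IdentityCompressor is a hypothetical stand-in for decentralizepy's configurable compressor, and the byte accounting done via total_meta/total_data in the real class is omitted.

    import pickle
    import numpy as np

    class IdentityCompressor:
        # Hypothetical stand-in for the compressor interface used by TCPRandomWalkBase
        # (compress/decompress for integer indices, compress_float/decompress_float for params).
        def compress(self, arr):
            return np.asarray(arr, dtype=np.int32).tobytes()
        def decompress(self, blob):
            return np.frombuffer(blob, dtype=np.int32)
        def compress_float(self, arr):
            return np.asarray(arr, dtype=np.float32).tobytes()
        def decompress_float(self, blob):
            return np.frombuffer(blob, dtype=np.float32)

    def encrypt(data, compressor):
        # Only dict payloads with "indices"/"params" are compressed; control
        # messages (tuples, strings, ...) are pickled as-is.
        if isinstance(data, dict):
            if "indices" in data:
                data["indices"] = compressor.compress(data["indices"])
            if "params" in data:
                data["params"] = compressor.compress_float(data["params"])
        return pickle.dumps(data)

    def decrypt(blob, compressor):
        data = pickle.loads(blob)
        if isinstance(data, dict):
            if "indices" in data:
                data["indices"] = compressor.decompress(data["indices"])
            if "params" in data:
                data["params"] = compressor.decompress_float(data["params"])
        return data

    # Toy round trip: a model payload gets compressed, a control message does not.
    model_msg = {"indices": np.arange(4, dtype=np.int32),
                 "params": np.ones(4, dtype=np.float32), "degree": 3}
    wire = encrypt(dict(model_msg), IdentityCompressor())
    back = decrypt(wire, IdentityCompressor())
    control_wire = encrypt(("random_walk", 42), IdentityCompressor())
    assert pickle.loads(control_wire) == ("random_walk", 42)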
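Note on the new JwinsDynamicGraph class: it ports JWINS-style wavelet compression onto SharingDynamicGraph. The flattened model (and its change, optionally accumulated across rounds) is transformed with pywt.wavedec, the top alpha fraction of coefficients ranked by absolute change is sent together with its indices, a lower_bound option can force rarely shared coefficients to be included, and the receiver splices the received coefficients into its own wavelet representation before inverting the transform. The standalone sketch below illustrates that round trip under simplified assumptions (random tensors instead of a real state_dict, no accumulation, no lower bound, no Metropolis-Hastings weighting); the helper names wavelet_transform, select_topk and reconstruct are illustrative and do not exist in decentralizepy.

    import numpy as np
    import pywt
    import torch

    def wavelet_transform(x, wavelet="haar", level=4):
        # Same idea as change_transformer_wavelet in the patch: multi-level DWT,
        # coefficients flattened into a single vector.
        coeffs = pywt.wavedec(x.numpy(), wavelet, level=level)
        arr, slices = pywt.coeffs_to_array(coeffs)
        return torch.from_numpy(arr.ravel()), slices, arr.shape

    def select_topk(change, alpha):
        # Keep the alpha fraction of wavelet coefficients with the largest absolute change.
        k = round(alpha * len(change))
        _, idx = torch.topk(change.abs(), k, dim=0, sorted=False)
        idx, _ = torch.sort(idx)
        return idx

    def reconstruct(coeff_vector, slices, shape, wavelet="haar"):
        # Inverse transform back to the flat parameter space (receiver side).
        coeffs = pywt.array_to_coeffs(
            coeff_vector.numpy().reshape(shape), slices, output_format="wavedec"
        )
        return torch.from_numpy(pywt.waverec(coeffs, wavelet=wavelet))

    # Toy round trip: the sender shares 10% of its wavelet coefficients, selected by
    # a (random) model change; the receiver splices them into its own coefficients
    # and inverts the transform, as the metro_hastings branch of _averaging does.
    sender_model = torch.randn(1024)
    receiver_model = torch.randn(1024)
    change = torch.randn(1024)  # stands in for model.model_change

    wt_sender, slices, shape = wavelet_transform(sender_model)
    wt_change, _, _ = wavelet_transform(change)
    idx = select_topk(wt_change, alpha=0.1)
    payload = {"indices": idx.numpy().astype(np.int32),
               "params": wt_sender[idx].numpy()}

    wt_receiver, _, _ = wavelet_transform(receiver_model)
    wt_receiver[torch.from_numpy(payload["indices"]).long()] = torch.from_numpy(payload["params"])
    merged = reconstruct(wt_receiver, slices, shape)
    assert merged.shape == receiver_model.shape

In the class itself the selection additionally updates shared_parameters_counter, which is what lets the lower_bound path detect coefficients shared fewer than communication_round * lower_bound times and top the selection up with randomly chosen indices.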