From 3be414ab24e9486b97afee7e120be9efb382aafd Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Wed, 15 Jun 2022 10:31:19 +0200 Subject: [PATCH] random walk jwins --- eval/run_xtimes_reddit_jwins.sh | 6 +- eval/run_xtimes_reddit_rws.sh | 4 +- .../config_reddit_jwins+_local.ini | 45 ++ .../config_reddit_sharing_jwinsasync.ini | 45 ++ .../config_reddit_sharing_jwinsasync30.ini | 45 ++ src/decentralizepy/sharing/JwinsDPSGDAsync.py | 549 ++++++++++++++++++ 6 files changed, 689 insertions(+), 5 deletions(-) create mode 100644 eval/step_configs/config_reddit_jwins+_local.ini create mode 100644 eval/step_configs/config_reddit_sharing_jwinsasync.ini create mode 100644 eval/step_configs/config_reddit_sharing_jwinsasync30.ini create mode 100644 src/decentralizepy/sharing/JwinsDPSGDAsync.py diff --git a/eval/run_xtimes_reddit_jwins.sh b/eval/run_xtimes_reddit_jwins.sh index 84bcf8e..72e98fc 100755 --- a/eval/run_xtimes_reddit_jwins.sh +++ b/eval/run_xtimes_reddit_jwins.sh @@ -55,9 +55,9 @@ export PYTHONFAULTHANDLER=1 # "step_configs/config_reddit_sharing.ini" tests=("step_configs/config_reddit_sharing.ini" "step_configs/config_reddit_jwins+.ini" "step_configs/config_reddit_topkacc.ini" "step_configs/config_reddit_subsampling.ini") # Learning rates -lr="1" +lr="0.1" # Batch size -batchsize="16" +batchsize="8" # The number of communication rounds per global epoch comm_rounds_per_global_epoch="10" procs=`expr $procs_per_machine \* $machines` @@ -106,7 +106,7 @@ do $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize $python_bin/crudini --set $config_file DATASET random_seed $seed - $env_python $eval_file -ro 0 -cte 1 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level + $env_python $eval_file -ro 0 -cte 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level echo $i is done sleep 300 echo end of sleep diff --git a/eval/run_xtimes_reddit_rws.sh b/eval/run_xtimes_reddit_rws.sh index 7b75e8e..57594e3 100755 --- a/eval/run_xtimes_reddit_rws.sh +++ b/eval/run_xtimes_reddit_rws.sh @@ -42,7 +42,7 @@ graph=96_regular.edges config_file=~/tmp/config.ini procs_per_machine=16 machines=6 -global_epochs=80 +global_epochs=160 eval_file=testing.py log_level=DEBUG @@ -54,7 +54,7 @@ export PYTHONFAULTHANDLER=1 # Base configs for which the gird search is done # tests=("step_configs/config_reddit_sharing_topKdynamicGraph.ini") # tests=("step_configs/config_reddit_sharing_topKsharingasyncrw.ini" "step_configs/config_reddit_sharing_topKdpsgdrwasync.ini" "step_configs/config_reddit_sharing_topKdpsgdrw.ini") -tests=("step_configs/config_reddit_sharing_dynamicGraphJwins30.ini") # ("step_configs/config_reddit_sharing_dpsgdrwasync0.ini") +tests=("step_configs/config_reddit_sharing_jwinsasync.ini" "step_configs/config_reddit_sharing_jwinsasync30.ini") # ("step_configs/config_reddit_sharing_dpsgdrwasync0.ini") # tests=("step_configs/config_reddit_sharing_dpsgdrw.ini" "step_configs/config_reddit_sharing_dpsgdrwasync.ini" "step_configs/config_reddit_sharing_sharingasyncrw.ini" "step_configs/config_reddit_sharing_sharingrw.ini") # Learning rates lr="1" diff --git a/eval/step_configs/config_reddit_jwins+_local.ini b/eval/step_configs/config_reddit_jwins+_local.ini 
new file mode 100644 index 0000000..6411cef --- /dev/null +++ b/eval/step_configs/config_reddit_jwins+_local.ini @@ -0,0 +1,45 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Reddit +dataset_class = Reddit +random_seed = 97 +model_class = RNN +train_dir = /home/jeffrey/Downloads/reddit/per_user_data/train +test_dir = /home/jeffrey/Downloads/reddit/new_small_data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = SGD +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCPRandomWalk +comm_class = TCPRandomWalk +addresses_filepath = ip_addr_6Machines.json +compression_package = decentralizepy.compression.Eliaszfplossy1 +compression_class = Eliaszfplossy1 +compress = True +sampler = equi + +[SHARING] +sharing_package = decentralizepy.sharing.JwinsDPSGDAsync +sharing_class = JwinsDPSGDAsync +alpha=0.0833 +lower_bound=0.2 +metro_hastings=False +change_based_selection = True +wavelet=sym2 +level= None +accumulation = True +accumulate_averaging_changes = True \ No newline at end of file diff --git a/eval/step_configs/config_reddit_sharing_jwinsasync.ini b/eval/step_configs/config_reddit_sharing_jwinsasync.ini new file mode 100644 index 0000000..4a09852 --- /dev/null +++ b/eval/step_configs/config_reddit_sharing_jwinsasync.ini @@ -0,0 +1,45 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Reddit +dataset_class = Reddit +random_seed = 97 +model_class = RNN +train_dir = /mnt/nfs/shared/leaf/data/reddit_new/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/reddit_new/new_small_data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = SGD +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCPRandomWalk +comm_class = TCPRandomWalk +addresses_filepath = ip_addr_6Machines.json +compression_package = decentralizepy.compression.Eliaszfplossy1 +compression_class = Eliaszfplossy1 +compress = True +sampler = equi_check_history + +[SHARING] +sharing_package = decentralizepy.sharing.JwinsDPSGDAsync +sharing_class = JwinsDPSGDAsync +alpha=0.0833 +lower_bound=0.2 +metro_hastings=False +change_based_selection = True +wavelet=sym2 +level= None +accumulation = True +accumulate_averaging_changes = True diff --git a/eval/step_configs/config_reddit_sharing_jwinsasync30.ini b/eval/step_configs/config_reddit_sharing_jwinsasync30.ini new file mode 100644 index 0000000..e006a57 --- /dev/null +++ b/eval/step_configs/config_reddit_sharing_jwinsasync30.ini @@ -0,0 +1,45 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Reddit +dataset_class = Reddit +random_seed = 97 +model_class = RNN +train_dir = /mnt/nfs/shared/leaf/data/reddit_new/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/reddit_new/new_small_data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = SGD +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = 
False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCPRandomWalk +comm_class = TCPRandomWalk +addresses_filepath = ip_addr_6Machines.json +compression_package = decentralizepy.compression.Eliaszfplossy1 +compression_class = Eliaszfplossy1 +compress = True +sampler = equi_check_history + +[SHARING] +sharing_package = decentralizepy.sharing.JwinsDPSGDAsync +sharing_class = JwinsDPSGDAsync +alpha=0.25 +lower_bound=0.2 +metro_hastings=False +change_based_selection = True +wavelet=sym2 +level= None +accumulation = True +accumulate_averaging_changes = True diff --git a/src/decentralizepy/sharing/JwinsDPSGDAsync.py b/src/decentralizepy/sharing/JwinsDPSGDAsync.py new file mode 100644 index 0000000..b0e0552 --- /dev/null +++ b/src/decentralizepy/sharing/JwinsDPSGDAsync.py @@ -0,0 +1,549 @@ +import json +import logging +import os +from pathlib import Path + +import numpy as np +import pywt +import torch +from decentralizepy.sharing.DPSGDRWAsync import DPSGDRWAsync +from decentralizepy.utils import conditional_value, identity + +def change_transformer_wavelet(x, wavelet, level): + """ + Transforms the model changes into wavelet frequency domain + + Parameters + ---------- + x : torch.Tensor + Model change in the space domain + wavelet : str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + + Returns + ------- + x : torch.Tensor + Representation of the change int the wavelet domain + """ + coeff = pywt.wavedec(x, wavelet, level=level) + data, coeff_slices = pywt.coeffs_to_array(coeff) + return torch.from_numpy(data.ravel()) + +class JwinsDPSGDAsync(DPSGDRWAsync): + """ + This class implements the vanilla version of partial model sharing. + + """ + + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + alpha=1.0, + dict_ordered=True, + save_shared=False, + metadata_cap=1.0, + wavelet="haar", + level=4, + change_based_selection=True, + save_accumulated="", + accumulation=False, + accumulate_averaging_changes=False, + rw_chance=1, + rw_length=4, + comm_interval=0.5, + min_interval=0.001, + max_lag=2, + lower_bound = 0.1, + metro_hastings=True, + ): + """ + Constructor + + Parameters + ---------- + rank : int + Local rank + machine_id : int + Global machine id + communication : decentralizepy.communication.Communication + Communication module used to send and receive messages + mapping : decentralizepy.mappings.Mapping + Mapping (rank, machine_id) -> uid + graph : decentralizepy.graphs.Graph + Graph reprensenting neighbors + model : decentralizepy.models.Model + Model to train + dataset : decentralizepy.datasets.Dataset + Dataset for sharing data. Not implemented yet! 
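The [SHARING] blocks in the three new configs above differ mainly in alpha (0.0833 vs 0.25) and the sampler; every value arrives from the INI file as a string. A minimal sketch, assuming plain standard-library configparser semantics (decentralizepy's own config loader may differ) and an illustrative read_sharing_section helper, of how the section could be coerced into the keyword arguments JwinsDPSGDAsync expects:

    # Sketch only: assumes standard configparser behaviour; helper name is illustrative.
    import configparser

    def read_sharing_section(path):
        cfg = configparser.ConfigParser()
        cfg.read(path)
        s = cfg["SHARING"]
        return {
            "alpha": s.getfloat("alpha"),                      # 0.0833 or 0.25 above
            "lower_bound": s.getfloat("lower_bound"),          # 0.2
            "metro_hastings": s.getboolean("metro_hastings"),  # False
            "change_based_selection": s.getboolean("change_based_selection"),
            "wavelet": s.get("wavelet"),                       # "sym2"
            # "level = None" is read as the string "None"; map it to Python None.
            "level": None if s.get("level").strip() == "None" else s.getint("level"),
            "accumulation": s.getboolean("accumulation"),
            "accumulate_averaging_changes": s.getboolean("accumulate_averaging_changes"),
        }

    kwargs = read_sharing_section("eval/step_configs/config_reddit_sharing_jwinsasync.ini")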
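For orientation, a round-trip sketch of the forward/inverse pair that change_transformer_wavelet is built on; only documented pywt/torch calls are used, and the helper names are illustrative rather than part of the patch:

    import pywt
    import torch

    def to_wavelet(x, wavelet="sym2", level=None):
        # Decompose, then pack all coefficient arrays into one flat vector.
        coeffs = pywt.wavedec(x.numpy(), wavelet, level=level)
        data, slices = pywt.coeffs_to_array(coeffs)
        return torch.from_numpy(data.ravel()), slices, data.shape

    def from_wavelet(flat, slices, shape, wavelet="sym2"):
        # Inverse: rebuild the coefficient list, then reconstruct the signal.
        coeffs = pywt.array_to_coeffs(
            flat.numpy().reshape(shape), slices, output_format="wavedec"
        )
        return torch.from_numpy(pywt.waverec(coeffs, wavelet=wavelet))

    x = torch.randn(1000, dtype=torch.float64)
    w, slices, shape = to_wavelet(x)
    assert torch.allclose(x, from_wavelet(w, slices, shape)[: x.shape[0]], atol=1e-8)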
TODO + log_dir : str + Location to write shared_params (only writing for 2 procs per machine) + alpha : float + Percentage of model to share + dict_ordered : bool + Specifies if the python dict maintains the order of insertion + save_shared : bool + Specifies if the indices of shared parameters should be logged + metadata_cap : float + Share full model when self.alpha > metadata_cap + wavelet: str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + change_based_selection : bool + use frequency change to select topk frequencies + save_accumulated : bool + True if accumulated weight change in the wavelet domain should be written to file. In case of accumulation + the accumulated change is stored. + accumulation : bool + True if the the indices to share should be selected based on accumulated frequency change + save_accumulated : bool + True if accumulated weight change should be written to file. In case of accumulation the accumulated change + is stored. If a change_transformer is used then the transformed change is stored. + change_transformer : (x: Tensor) -> Tensor + A function that transforms the model change into other domains. Default: identity function + accumulate_averaging_changes: bool + True if the accumulation should account the model change due to averaging + """ + self.wavelet = wavelet + self.level = level + + super().__init__( + rank, machine_id, communication, mapping, graph, model, dataset, log_dir, rw_chance, rw_length, comm_interval, min_interval, max_lag + ) + self.alpha = alpha + self.dict_ordered = dict_ordered + self.save_shared = save_shared + self.metadata_cap = metadata_cap + self.accumulation = accumulation + self.save_accumulated = conditional_value(save_accumulated, "", False) + self.change_transformer = lambda x: change_transformer_wavelet(x, wavelet, level) + self.accumulate_averaging_changes = accumulate_averaging_changes + + # getting the initial model + self.shapes = [] + self.lens = [] + with torch.no_grad(): + tensors_to_cat = [] + for _, v in self.model.state_dict().items(): + self.shapes.append(v.shape) + t = v.flatten() + self.lens.append(t.shape[0]) + tensors_to_cat.append(t) + self.init_model = torch.cat(tensors_to_cat, dim=0) + if self.accumulation: + self.model.accumulated_changes = torch.zeros_like( + self.change_transformer(self.init_model) + ) + self.prev = self.init_model + self.number_of_params = self.init_model.shape[0] + if self.save_accumulated: + self.model_change_path = os.path.join( + self.log_dir, "model_change/{}".format(self.rank) + ) + Path(self.model_change_path).mkdir(parents=True, exist_ok=True) + + self.model_val_path = os.path.join( + self.log_dir, "model_val/{}".format(self.rank) + ) + Path(self.model_val_path).mkdir(parents=True, exist_ok=True) + + # Only save for 2 procs: Save space + if self.save_shared and not (rank == 0 or rank == 1): + self.save_shared = False + + if self.save_shared: + self.folder_path = os.path.join( + self.log_dir, "shared_params/{}".format(self.rank) + ) + Path(self.folder_path).mkdir(parents=True, exist_ok=True) + + self.model.shared_parameters_counter = torch.zeros( + self.change_transformer(self.init_model).shape[0], dtype=torch.int32 + ) + + self.change_based_selection = change_based_selection + + # Do a dummy transform to get the shape and coefficents slices + coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level) + data, coeff_slices = pywt.coeffs_to_array(coeff) + self.wt_shape = data.shape + 
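The constructor above flattens the entire state_dict into a single vector and records the per-tensor shapes and lengths so that the averaging step can slice the vector back into tensors. A self-contained sketch of that flatten/unflatten pair (helper names are illustrative):

    import torch

    def flatten_state_dict(model):
        shapes, lens, parts = [], [], []
        for _, v in model.state_dict().items():
            shapes.append(v.shape)
            flat = v.flatten()
            lens.append(flat.shape[0])
            parts.append(flat)
        return torch.cat(parts, dim=0), shapes, lens

    def unflatten_into(model, flat, shapes, lens):
        # Slice the flat vector back into per-tensor blocks, in state_dict order.
        out, start = {}, 0
        for (key, _), shape, length in zip(model.state_dict().items(), shapes, lens):
            out[key] = flat[start : start + length].reshape(shape)
            start += length
        model.load_state_dict(out)

    model = torch.nn.Linear(4, 2)
    vec, shapes, lens = flatten_state_dict(model)
    unflatten_into(model, vec, shapes, lens)  # round-trip leaves the model unchanged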
self.coeff_slices = coeff_slices + + self.lower_bound = lower_bound + self.metro_hastings = metro_hastings + if self.lower_bound > 0: + self.start_lower_bounding_at = 1 / self.lower_bound + + def apply_wavelet(self): + """ + Does wavelet transformation of the model parameters and selects topK (alpha) of them in the frequency domain + based on the undergone change during the current training step + + Returns + ------- + tuple + (a,b). a: selected wavelet coefficients, b: Their indices. + + """ + + data = self.pre_share_model_transformed + if self.change_based_selection: + diff = self.model.model_change + _, index = torch.topk( + diff.abs(), + round(self.alpha * len(diff)), + dim=0, + sorted=False, + ) + else: + _, index = torch.topk( + data.abs(), + round(self.alpha * len(data)), + dim=0, + sorted=False, + ) + ind, _ = torch.sort(index) + + if self.lower_bound == 0.0: + return data[ind].clone().detach(), ind + + if self.communication_round > self.start_lower_bounding_at: + # because the superclass increases it where it is inconvenient for this subclass + currently_shared = self.model.shared_parameters_counter.clone().detach() + currently_shared[ind] += 1 + ind_small = ( + currently_shared < self.communication_round * self.lower_bound + ).nonzero(as_tuple=True)[0] + ind_small_unique = np.setdiff1d( + ind_small.numpy(), ind.numpy(), assume_unique=True + ) + take_max = round(self.lower_bound * self.alpha * data.shape[0]) + logging.info( + "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max + ) + if take_max > ind_small_unique.shape[0]: + take_max = ind_small_unique.shape[0] + to_take = torch.rand(ind_small_unique.shape[0]) + _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False) + ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take] + logging.info("lower bounding: %i %i", len(ind), len(ind_bound)) + # val = torch.concat(val, G_topk[ind_bound]) # not really needed, as thes are abs values and not further used + ind = torch.cat([ind, ind_bound]) + + index, _ = torch.sort(ind) + return data[index].clone().detach(), index + + def serialized_model(self): + """ + Convert model to json dict. self.alpha specifies the fraction of model to send. 
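To make the selection rule in apply_wavelet concrete, here is a stripped-down sketch with the same semantics but without class state: top-k by magnitude of the wavelet-domain change, plus, once enough rounds have passed, a random subset of rarely shared coordinates forced in by the lower bound (names are illustrative):

    import numpy as np
    import torch

    def select_indices(change, counter, comm_round, alpha=0.0833, lower_bound=0.2):
        # Regular top-k on the magnitude of the (accumulated) wavelet-domain change.
        _, idx = torch.topk(change.abs(), round(alpha * len(change)), sorted=False)
        if lower_bound == 0.0 or comm_round <= 1 / lower_bound:
            return torch.sort(idx)[0]
        # Coordinates shared too rarely so far, excluding the ones already picked ...
        shared = counter.clone()
        shared[idx] += 1
        starved = (shared < comm_round * lower_bound).nonzero(as_tuple=True)[0]
        starved = np.setdiff1d(starved.numpy(), idx.numpy(), assume_unique=True)
        # ... of which a random subset (at most lower_bound * alpha * N) is forced in.
        take = min(round(lower_bound * alpha * len(change)), starved.shape[0])
        pick = torch.topk(torch.rand(starved.shape[0]), take, sorted=False)[1]
        idx = torch.cat([idx, torch.from_numpy(starved)[pick]])
        return torch.sort(idx)[0]

    counter = torch.zeros(10_000, dtype=torch.int32)
    idx = select_indices(torch.randn(10_000), counter, comm_round=20)
    counter[idx] += 1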
+ + Returns + ------- + dict + Model converted to json dict + + """ + m = dict() + if self.alpha >= self.metadata_cap: # Share fully + data = self.pre_share_model_transformed + m["params"] = data.numpy() + if self.model.accumulated_changes is not None: + self.model.accumulated_changes = torch.zeros_like( + self.model.accumulated_changes + ) + return m + + with torch.no_grad(): + topk, self.G_topk = self.apply_wavelet() + self.model.shared_parameters_counter[self.G_topk] += 1 + self.model.rewind_accumulation(self.G_topk) + if self.save_shared: + shared_params = dict() + shared_params["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v in self.model.state_dict().items(): + shapes[k] = list(v.shape) + shared_params["shapes"] = shapes + + shared_params[self.communication_round] = self.G_topk.tolist() + + with open( + os.path.join( + self.folder_path, + "{}_shared_params.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(shared_params, of) + + if not self.dict_ordered: + raise NotImplementedError + + m["alpha"] = self.alpha + + m["indices"] = self.G_topk.numpy().astype(np.int32) + + m["params"] = topk.numpy() + + m["send_partial"] = True + + assert len(m["indices"]) == len(m["params"]) + logging.info("Elements sending: {}".format(len(m["indices"]))) + + logging.info("Generated dictionary to send") + + logging.info("Converted dictionary to pickle") + + return m + + + def deserialized_model(self, m): + """ + Convert received dict to state_dict. + + Parameters + ---------- + m : dict + received dict + + Returns + ------- + state_dict + state_dict of received + + """ + ret = dict() + if "send_partial" not in m: + params = m["params"] + params_tensor = torch.tensor(params) + ret["params"] = params_tensor + return ret + + with torch.no_grad(): + if not self.dict_ordered: + raise NotImplementedError + + params_tensor = torch.tensor(m["params"]) + indices_tensor = torch.tensor(m["indices"], dtype=torch.long) + ret = dict() + ret["indices"] = indices_tensor + ret["params"] = params_tensor + ret["send_partial"] = True + return ret + + def _averaging(self): + """ + Averages the received model with the local model + + """ + with torch.no_grad(): + total = None + weight_total = 0 + wt_params = self.pre_share_model_transformed + if not self.metro_hastings: + weight_vector = torch.ones_like(wt_params) + datas = [] + batch = self._preprocessing_received_models() + for n, vals in batch.items(): + if len(vals) > 1: + data = None + degree = 0 + # this should no longer happen, unless we get two rw from the same originator + logging.info("averaging double messages for %i", n) + for val in vals: + degree_sub, iteration, data_sub = val + if data is None: + data = data_sub + degree = degree + else: + for key, weight_val in data_sub.items(): + data[key] += weight_val + degree = max(degree, degree_sub) + for key, weight_val in data.items(): + data[key] /= len(vals) + else: + degree, iteration, data = vals[0] + #degree, iteration, data = self.peer_deques[n].popleft() + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + #data = self.deserialized_model(data) + params = data["params"] + if "indices" in data: + indices = data["indices"] + if not self.metro_hastings: + weight_vector[indices] += 1 + topkwf = torch.zeros_like(wt_params) + topkwf[indices] = params + topkwf = topkwf.reshape(self.wt_shape) + datas.append(topkwf) + else: + # use local data to complement + topkwf = wt_params.clone().detach() + topkwf[indices] = 
params + topkwf = topkwf.reshape(self.wt_shape) + + else: + topkwf = params.reshape(self.wt_shape) + if not self.metro_hastings: + weight_vector += 1 + datas.append(topkwf) + + if self.metro_hastings: + weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight_total += weight + if total is None: + total = weight * topkwf + else: + total += weight * topkwf + if not self.metro_hastings: + weight_vector = 1.0 / weight_vector + # speed up by exploiting sparsity + total = wt_params * weight_vector + for d in datas: + total += d * weight_vector + else: + # Metro-Hastings + total += (1 - weight_total) * wt_params + + avg_wf_params = pywt.array_to_coeffs( + total.numpy(), self.coeff_slices, output_format="wavedec" + ) + reverse_total = torch.from_numpy( + pywt.waverec(avg_wf_params, wavelet=self.wavelet) + ) + + start_index = 0 + std_dict = {} + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + std_dict[key] = reverse_total[start_index:end_index].reshape( + self.shapes[i] + ) + start_index = end_index + + self.model.load_state_dict(std_dict) + to_cat = [] + for _, v in self.model.state_dict().items(): + vf = v.clone().detach().flatten() + to_cat.append(vf) + self.init_model = torch.cat(to_cat) + self._transformed = self.change_transformer(self.init_model) + + + def _pre_step(self): + """ + Called at the beginning of step. + + """ + logging.info("PartialModel _pre_step") + with torch.no_grad(): + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + self.pre_share_model = torch.cat(tensors_to_cat, dim=0) + # Would only need one of the transforms + self.pre_share_model_transformed = self.change_transformer( + self.pre_share_model + ) + + + def _post_step(self): + """ + Called at the end of step. + + """ + change = self.change_transformer(self.pre_share_model - self.init_model) # self.init_model is set in _averaging + if self.accumulation: + # Need to accumulate in _pre_step as the accumulation gets rewind during the step + self.model.accumulated_changes += change + change = self.model.accumulated_changes.clone().detach() + # stores change of the model due to training, change due to averaging is not accounted + self.model.model_change = change + + + def _send_rw(self): + def send(): + # will have to send the data twice to make the code simpler (for the beginning) + if self.alpha >= self.metadata_cap: + rw_data = { + "params": self.init_model.numpy(), + "rw": True, + "degree": self.number_of_neighbors, + "iteration": self.communication_round, + "visited": [self.uid], + "fuel": self.rw_length - 1, + } + else: + rw_data = { + "params": self._transformed[self.G_topk].numpy(), + "indices": self.G_topk.numpy().astype(np.int32), + "rw": True, + "degree": self.number_of_neighbors, + "iteration": self.communication_round, + "visited": [self.uid], + "fuel": self.rw_length - 1, + "send_partial": True, + } + logging.info("new rw message") + self.communication.send(None, rw_data) + + rw_chance = self.rw_chance + self.serialized_model() # dummy call to get self.G_topK + while rw_chance >= 1.0: + # TODO: make sure they are not sent to the same neighbour + send() + rw_chance -= 1 + rw_now = torch.rand(size=(1,), generator=self.random_generator).item() + if rw_now < rw_chance: + send() + + + + def save_vector(self, v, s): + """ + Saves the given vector to the file. 
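The non-Metropolis-Hastings branch of _averaging above amounts to a per-coordinate average in the wavelet domain: each received partial model is scattered into a dense vector, and every coordinate is divided by the number of models that actually contributed to it. A minimal sketch (degree handling, wavelet reshaping, and the Metropolis-Hastings branch omitted; names are illustrative):

    import torch

    def average_partials(local, received):
        # received: list of (indices, values) pairs from partially shared models.
        weight = torch.ones_like(local)   # the local model contributes everywhere
        total = local.clone()
        for idx, vals in received:
            dense = torch.zeros_like(local)
            dense[idx] = vals
            total += dense
            weight[idx] += 1              # one more contributor for these coordinates
        return total / weight             # per-coordinate average

    local = torch.randn(8)
    received = [(torch.tensor([0, 3, 5]), torch.randn(3)),
                (torch.tensor([3, 7]), torch.randn(2))]
    avg = average_partials(local, received)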
+ + Parameters + ---------- + v : torch.Tensor + The torch tensor to write to file + s : str + Path to folder to write to + + """ + output_dict = dict() + output_dict["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v1 in self.model.state_dict().items(): + shapes[k] = list(v1.shape) + output_dict["shapes"] = shapes + + output_dict["tensor"] = v.tolist() + + with open( + os.path.join( + s, + "{}.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(output_dict, of) + + def save_change(self): + """ + Saves the model change for the current iteration + + """ + self.save_vector(self.model.model_change, self.model_change_path) -- GitLab
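save_vector and save_change above emit one JSON file per communication round with keys "order", "shapes", and "tensor" (the tensor being the wavelet-domain change). A small sketch of reading such a file back for offline analysis; the path is illustrative:

    import json
    import torch

    def load_saved_vector(path):
        with open(path) as f:
            d = json.load(f)
        # "order"/"shapes" record the state_dict layout at save time;
        # the vector itself lives in the wavelet domain.
        return torch.tensor(d["tensor"]), d["order"], d["shapes"]

    vec, order, shapes = load_saved_vector("model_change/0/5.json")
    print(vec.abs().mean().item(), len(order))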