From 3be414ab24e9486b97afee7e120be9efb382aafd Mon Sep 17 00:00:00 2001 From: Jeffrey Wigger <jeffrey.wigger@epfl.ch> Date: Wed, 15 Jun 2022 10:31:19 +0200 Subject: [PATCH] random walk jwins --- eval/run_xtimes_reddit_jwins.sh | 6 +- eval/run_xtimes_reddit_rws.sh | 4 +- .../config_reddit_jwins+_local.ini | 45 ++ .../config_reddit_sharing_jwinsasync.ini | 45 ++ .../config_reddit_sharing_jwinsasync30.ini | 45 ++ src/decentralizepy/sharing/JwinsDPSGDAsync.py | 549 ++++++++++++++++++ 6 files changed, 689 insertions(+), 5 deletions(-) create mode 100644 eval/step_configs/config_reddit_jwins+_local.ini create mode 100644 eval/step_configs/config_reddit_sharing_jwinsasync.ini create mode 100644 eval/step_configs/config_reddit_sharing_jwinsasync30.ini create mode 100644 src/decentralizepy/sharing/JwinsDPSGDAsync.py diff --git a/eval/run_xtimes_reddit_jwins.sh b/eval/run_xtimes_reddit_jwins.sh index 84bcf8e..72e98fc 100755 --- a/eval/run_xtimes_reddit_jwins.sh +++ b/eval/run_xtimes_reddit_jwins.sh @@ -55,9 +55,9 @@ export PYTHONFAULTHANDLER=1 # "step_configs/config_reddit_sharing.ini" tests=("step_configs/config_reddit_sharing.ini" "step_configs/config_reddit_jwins+.ini" "step_configs/config_reddit_topkacc.ini" "step_configs/config_reddit_subsampling.ini") # Learning rates -lr="1" +lr="0.1" # Batch size -batchsize="16" +batchsize="8" # The number of communication rounds per global epoch comm_rounds_per_global_epoch="10" procs=`expr $procs_per_machine \* $machines` @@ -106,7 +106,7 @@ do $python_bin/crudini --set $config_file TRAIN_PARAMS rounds $batches_per_comm_round $python_bin/crudini --set $config_file TRAIN_PARAMS batch_size $batchsize $python_bin/crudini --set $config_file DATASET random_seed $seed - $env_python $eval_file -ro 0 -cte 1 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level + $env_python $eval_file -ro 0 -cte 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level echo $i is done sleep 300 echo end of sleep diff --git a/eval/run_xtimes_reddit_rws.sh b/eval/run_xtimes_reddit_rws.sh index 7b75e8e..57594e3 100755 --- a/eval/run_xtimes_reddit_rws.sh +++ b/eval/run_xtimes_reddit_rws.sh @@ -42,7 +42,7 @@ graph=96_regular.edges config_file=~/tmp/config.ini procs_per_machine=16 machines=6 -global_epochs=80 +global_epochs=160 eval_file=testing.py log_level=DEBUG @@ -54,7 +54,7 @@ export PYTHONFAULTHANDLER=1 # Base configs for which the gird search is done # tests=("step_configs/config_reddit_sharing_topKdynamicGraph.ini") # tests=("step_configs/config_reddit_sharing_topKsharingasyncrw.ini" "step_configs/config_reddit_sharing_topKdpsgdrwasync.ini" "step_configs/config_reddit_sharing_topKdpsgdrw.ini") -tests=("step_configs/config_reddit_sharing_dynamicGraphJwins30.ini") # ("step_configs/config_reddit_sharing_dpsgdrwasync0.ini") +tests=("step_configs/config_reddit_sharing_jwinsasync.ini" "step_configs/config_reddit_sharing_jwinsasync30.ini") # ("step_configs/config_reddit_sharing_dpsgdrwasync0.ini") # tests=("step_configs/config_reddit_sharing_dpsgdrw.ini" "step_configs/config_reddit_sharing_dpsgdrwasync.ini" "step_configs/config_reddit_sharing_sharingasyncrw.ini" "step_configs/config_reddit_sharing_sharingrw.ini") # Learning rates lr="1" diff --git a/eval/step_configs/config_reddit_jwins+_local.ini b/eval/step_configs/config_reddit_jwins+_local.ini 
new file mode 100644 index 0000000..6411cef --- /dev/null +++ b/eval/step_configs/config_reddit_jwins+_local.ini @@ -0,0 +1,45 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Reddit +dataset_class = Reddit +random_seed = 97 +model_class = RNN +train_dir = /home/jeffrey/Downloads/reddit/per_user_data/train +test_dir = /home/jeffrey/Downloads/reddit/new_small_data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = SGD +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCPRandomWalk +comm_class = TCPRandomWalk +addresses_filepath = ip_addr_6Machines.json +compression_package = decentralizepy.compression.Eliaszfplossy1 +compression_class = Eliaszfplossy1 +compress = True +sampler = equi + +[SHARING] +sharing_package = decentralizepy.sharing.JwinsDPSGDAsync +sharing_class = JwinsDPSGDAsync +alpha=0.0833 +lower_bound=0.2 +metro_hastings=False +change_based_selection = True +wavelet=sym2 +level= None +accumulation = True +accumulate_averaging_changes = True \ No newline at end of file diff --git a/eval/step_configs/config_reddit_sharing_jwinsasync.ini b/eval/step_configs/config_reddit_sharing_jwinsasync.ini new file mode 100644 index 0000000..4a09852 --- /dev/null +++ b/eval/step_configs/config_reddit_sharing_jwinsasync.ini @@ -0,0 +1,45 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Reddit +dataset_class = Reddit +random_seed = 97 +model_class = RNN +train_dir = /mnt/nfs/shared/leaf/data/reddit_new/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/reddit_new/new_small_data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = SGD +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCPRandomWalk +comm_class = TCPRandomWalk +addresses_filepath = ip_addr_6Machines.json +compression_package = decentralizepy.compression.Eliaszfplossy1 +compression_class = Eliaszfplossy1 +compress = True +sampler = equi_check_history + +[SHARING] +sharing_package = decentralizepy.sharing.JwinsDPSGDAsync +sharing_class = JwinsDPSGDAsync +alpha=0.0833 +lower_bound=0.2 +metro_hastings=False +change_based_selection = True +wavelet=sym2 +level= None +accumulation = True +accumulate_averaging_changes = True diff --git a/eval/step_configs/config_reddit_sharing_jwinsasync30.ini b/eval/step_configs/config_reddit_sharing_jwinsasync30.ini new file mode 100644 index 0000000..e006a57 --- /dev/null +++ b/eval/step_configs/config_reddit_sharing_jwinsasync30.ini @@ -0,0 +1,45 @@ +[DATASET] +dataset_package = decentralizepy.datasets.Reddit +dataset_class = Reddit +random_seed = 97 +model_class = RNN +train_dir = /mnt/nfs/shared/leaf/data/reddit_new/per_user_data/train +test_dir = /mnt/nfs/shared/leaf/data/reddit_new/new_small_data/test +; python list of fractions below +sizes = + +[OPTIMIZER_PARAMS] +optimizer_package = torch.optim +optimizer_class = SGD +lr = 0.001 + +[TRAIN_PARAMS] +training_package = decentralizepy.training.Training +training_class = Training +rounds = 47 +full_epochs = 
False +batch_size = 16 +shuffle = True +loss_package = torch.nn +loss_class = CrossEntropyLoss + +[COMMUNICATION] +comm_package = decentralizepy.communication.TCPRandomWalk +comm_class = TCPRandomWalk +addresses_filepath = ip_addr_6Machines.json +compression_package = decentralizepy.compression.Eliaszfplossy1 +compression_class = Eliaszfplossy1 +compress = True +sampler = equi_check_history + +[SHARING] +sharing_package = decentralizepy.sharing.JwinsDPSGDAsync +sharing_class = JwinsDPSGDAsync +alpha=0.25 +lower_bound=0.2 +metro_hastings=False +change_based_selection = True +wavelet=sym2 +level= None +accumulation = True +accumulate_averaging_changes = True diff --git a/src/decentralizepy/sharing/JwinsDPSGDAsync.py b/src/decentralizepy/sharing/JwinsDPSGDAsync.py new file mode 100644 index 0000000..b0e0552 --- /dev/null +++ b/src/decentralizepy/sharing/JwinsDPSGDAsync.py @@ -0,0 +1,549 @@ +import json +import logging +import os +from pathlib import Path + +import numpy as np +import pywt +import torch +from decentralizepy.sharing.DPSGDRWAsync import DPSGDRWAsync +from decentralizepy.utils import conditional_value, identity + +def change_transformer_wavelet(x, wavelet, level): + """ + Transforms the model changes into wavelet frequency domain + + Parameters + ---------- + x : torch.Tensor + Model change in the space domain + wavelet : str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + + Returns + ------- + x : torch.Tensor + Representation of the change int the wavelet domain + """ + coeff = pywt.wavedec(x, wavelet, level=level) + data, coeff_slices = pywt.coeffs_to_array(coeff) + return torch.from_numpy(data.ravel()) + +class JwinsDPSGDAsync(DPSGDRWAsync): + """ + This class implements the vanilla version of partial model sharing. + + """ + + def __init__( + self, + rank, + machine_id, + communication, + mapping, + graph, + model, + dataset, + log_dir, + alpha=1.0, + dict_ordered=True, + save_shared=False, + metadata_cap=1.0, + wavelet="haar", + level=4, + change_based_selection=True, + save_accumulated="", + accumulation=False, + accumulate_averaging_changes=False, + rw_chance=1, + rw_length=4, + comm_interval=0.5, + min_interval=0.001, + max_lag=2, + lower_bound = 0.1, + metro_hastings=True, + ): + """ + Constructor + + Parameters + ---------- + rank : int + Local rank + machine_id : int + Global machine id + communication : decentralizepy.communication.Communication + Communication module used to send and receive messages + mapping : decentralizepy.mappings.Mapping + Mapping (rank, machine_id) -> uid + graph : decentralizepy.graphs.Graph + Graph reprensenting neighbors + model : decentralizepy.models.Model + Model to train + dataset : decentralizepy.datasets.Dataset + Dataset for sharing data. Not implemented yet! 
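The [SHARING] blocks in the three new configs above differ mainly in alpha (0.0833 vs 0.25) and the sampler; every value arrives from the INI file as a string. A minimal sketch, assuming plain standard-library configparser semantics (decentralizepy's own config loader may differ) and an illustrative read_sharing_section helper, of how the section could be coerced into the keyword arguments JwinsDPSGDAsync expects:

    # Sketch only: assumes standard configparser behaviour; helper name is illustrative.
    import configparser

    def read_sharing_section(path):
        cfg = configparser.ConfigParser()
        cfg.read(path)
        s = cfg["SHARING"]
        return {
            "alpha": s.getfloat("alpha"),                      # 0.0833 or 0.25 above
            "lower_bound": s.getfloat("lower_bound"),          # 0.2
            "metro_hastings": s.getboolean("metro_hastings"),  # False
            "change_based_selection": s.getboolean("change_based_selection"),
            "wavelet": s.get("wavelet"),                       # "sym2"
            # "level = None" is read as the string "None"; map it to Python None.
            "level": None if s.get("level").strip() == "None" else s.getint("level"),
            "accumulation": s.getboolean("accumulation"),
            "accumulate_averaging_changes": s.getboolean("accumulate_averaging_changes"),
        }

    kwargs = read_sharing_section("eval/step_configs/config_reddit_sharing_jwinsasync.ini")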
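For orientation, a round-trip sketch of the forward/inverse pair that change_transformer_wavelet is built on; only documented pywt/torch calls are used, and the helper names are illustrative rather than part of the patch:

    import pywt
    import torch

    def to_wavelet(x, wavelet="sym2", level=None):
        # Decompose, then pack all coefficient arrays into one flat vector.
        coeffs = pywt.wavedec(x.numpy(), wavelet, level=level)
        data, slices = pywt.coeffs_to_array(coeffs)
        return torch.from_numpy(data.ravel()), slices, data.shape

    def from_wavelet(flat, slices, shape, wavelet="sym2"):
        # Inverse: rebuild the coefficient list, then reconstruct the signal.
        coeffs = pywt.array_to_coeffs(
            flat.numpy().reshape(shape), slices, output_format="wavedec"
        )
        return torch.from_numpy(pywt.waverec(coeffs, wavelet=wavelet))

    x = torch.randn(1000, dtype=torch.float64)
    w, slices, shape = to_wavelet(x)
    assert torch.allclose(x, from_wavelet(w, slices, shape)[: x.shape[0]], atol=1e-8)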
TODO + log_dir : str + Location to write shared_params (only writing for 2 procs per machine) + alpha : float + Percentage of model to share + dict_ordered : bool + Specifies if the python dict maintains the order of insertion + save_shared : bool + Specifies if the indices of shared parameters should be logged + metadata_cap : float + Share full model when self.alpha > metadata_cap + wavelet: str + name of the wavelet to be used in gradient compression + level: int + name of the wavelet to be used in gradient compression + change_based_selection : bool + use frequency change to select topk frequencies + save_accumulated : bool + True if accumulated weight change in the wavelet domain should be written to file. In case of accumulation + the accumulated change is stored. + accumulation : bool + True if the the indices to share should be selected based on accumulated frequency change + save_accumulated : bool + True if accumulated weight change should be written to file. In case of accumulation the accumulated change + is stored. If a change_transformer is used then the transformed change is stored. + change_transformer : (x: Tensor) -> Tensor + A function that transforms the model change into other domains. Default: identity function + accumulate_averaging_changes: bool + True if the accumulation should account the model change due to averaging + """ + self.wavelet = wavelet + self.level = level + + super().__init__( + rank, machine_id, communication, mapping, graph, model, dataset, log_dir, rw_chance, rw_length, comm_interval, min_interval, max_lag + ) + self.alpha = alpha + self.dict_ordered = dict_ordered + self.save_shared = save_shared + self.metadata_cap = metadata_cap + self.accumulation = accumulation + self.save_accumulated = conditional_value(save_accumulated, "", False) + self.change_transformer = lambda x: change_transformer_wavelet(x, wavelet, level) + self.accumulate_averaging_changes = accumulate_averaging_changes + + # getting the initial model + self.shapes = [] + self.lens = [] + with torch.no_grad(): + tensors_to_cat = [] + for _, v in self.model.state_dict().items(): + self.shapes.append(v.shape) + t = v.flatten() + self.lens.append(t.shape[0]) + tensors_to_cat.append(t) + self.init_model = torch.cat(tensors_to_cat, dim=0) + if self.accumulation: + self.model.accumulated_changes = torch.zeros_like( + self.change_transformer(self.init_model) + ) + self.prev = self.init_model + self.number_of_params = self.init_model.shape[0] + if self.save_accumulated: + self.model_change_path = os.path.join( + self.log_dir, "model_change/{}".format(self.rank) + ) + Path(self.model_change_path).mkdir(parents=True, exist_ok=True) + + self.model_val_path = os.path.join( + self.log_dir, "model_val/{}".format(self.rank) + ) + Path(self.model_val_path).mkdir(parents=True, exist_ok=True) + + # Only save for 2 procs: Save space + if self.save_shared and not (rank == 0 or rank == 1): + self.save_shared = False + + if self.save_shared: + self.folder_path = os.path.join( + self.log_dir, "shared_params/{}".format(self.rank) + ) + Path(self.folder_path).mkdir(parents=True, exist_ok=True) + + self.model.shared_parameters_counter = torch.zeros( + self.change_transformer(self.init_model).shape[0], dtype=torch.int32 + ) + + self.change_based_selection = change_based_selection + + # Do a dummy transform to get the shape and coefficents slices + coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level) + data, coeff_slices = pywt.coeffs_to_array(coeff) + self.wt_shape = data.shape + 
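The constructor above flattens the entire state_dict into a single vector and records the per-tensor shapes and lengths so that the averaging step can slice the vector back into tensors. A self-contained sketch of that flatten/unflatten pair (helper names are illustrative):

    import torch

    def flatten_state_dict(model):
        shapes, lens, parts = [], [], []
        for _, v in model.state_dict().items():
            shapes.append(v.shape)
            flat = v.flatten()
            lens.append(flat.shape[0])
            parts.append(flat)
        return torch.cat(parts, dim=0), shapes, lens

    def unflatten_into(model, flat, shapes, lens):
        # Slice the flat vector back into per-tensor blocks, in state_dict order.
        out, start = {}, 0
        for (key, _), shape, length in zip(model.state_dict().items(), shapes, lens):
            out[key] = flat[start : start + length].reshape(shape)
            start += length
        model.load_state_dict(out)

    model = torch.nn.Linear(4, 2)
    vec, shapes, lens = flatten_state_dict(model)
    unflatten_into(model, vec, shapes, lens)  # round-trip leaves the model unchanged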
self.coeff_slices = coeff_slices + + self.lower_bound = lower_bound + self.metro_hastings = metro_hastings + if self.lower_bound > 0: + self.start_lower_bounding_at = 1 / self.lower_bound + + def apply_wavelet(self): + """ + Does wavelet transformation of the model parameters and selects topK (alpha) of them in the frequency domain + based on the undergone change during the current training step + + Returns + ------- + tuple + (a,b). a: selected wavelet coefficients, b: Their indices. + + """ + + data = self.pre_share_model_transformed + if self.change_based_selection: + diff = self.model.model_change + _, index = torch.topk( + diff.abs(), + round(self.alpha * len(diff)), + dim=0, + sorted=False, + ) + else: + _, index = torch.topk( + data.abs(), + round(self.alpha * len(data)), + dim=0, + sorted=False, + ) + ind, _ = torch.sort(index) + + if self.lower_bound == 0.0: + return data[ind].clone().detach(), ind + + if self.communication_round > self.start_lower_bounding_at: + # because the superclass increases it where it is inconvenient for this subclass + currently_shared = self.model.shared_parameters_counter.clone().detach() + currently_shared[ind] += 1 + ind_small = ( + currently_shared < self.communication_round * self.lower_bound + ).nonzero(as_tuple=True)[0] + ind_small_unique = np.setdiff1d( + ind_small.numpy(), ind.numpy(), assume_unique=True + ) + take_max = round(self.lower_bound * self.alpha * data.shape[0]) + logging.info( + "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max + ) + if take_max > ind_small_unique.shape[0]: + take_max = ind_small_unique.shape[0] + to_take = torch.rand(ind_small_unique.shape[0]) + _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False) + ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take] + logging.info("lower bounding: %i %i", len(ind), len(ind_bound)) + # val = torch.concat(val, G_topk[ind_bound]) # not really needed, as thes are abs values and not further used + ind = torch.cat([ind, ind_bound]) + + index, _ = torch.sort(ind) + return data[index].clone().detach(), index + + def serialized_model(self): + """ + Convert model to json dict. self.alpha specifies the fraction of model to send. 
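To make the selection rule in apply_wavelet concrete, here is a stripped-down sketch with the same semantics but without class state: top-k by magnitude of the wavelet-domain change, plus, once enough rounds have passed, a random subset of rarely shared coordinates forced in by the lower bound (names are illustrative):

    import numpy as np
    import torch

    def select_indices(change, counter, comm_round, alpha=0.0833, lower_bound=0.2):
        # Regular top-k on the magnitude of the (accumulated) wavelet-domain change.
        _, idx = torch.topk(change.abs(), round(alpha * len(change)), sorted=False)
        if lower_bound == 0.0 or comm_round <= 1 / lower_bound:
            return torch.sort(idx)[0]
        # Coordinates shared too rarely so far, excluding the ones already picked ...
        shared = counter.clone()
        shared[idx] += 1
        starved = (shared < comm_round * lower_bound).nonzero(as_tuple=True)[0]
        starved = np.setdiff1d(starved.numpy(), idx.numpy(), assume_unique=True)
        # ... of which a random subset (at most lower_bound * alpha * N) is forced in.
        take = min(round(lower_bound * alpha * len(change)), starved.shape[0])
        pick = torch.topk(torch.rand(starved.shape[0]), take, sorted=False)[1]
        idx = torch.cat([idx, torch.from_numpy(starved)[pick]])
        return torch.sort(idx)[0]

    counter = torch.zeros(10_000, dtype=torch.int32)
    idx = select_indices(torch.randn(10_000), counter, comm_round=20)
    counter[idx] += 1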
+ + Returns + ------- + dict + Model converted to json dict + + """ + m = dict() + if self.alpha >= self.metadata_cap: # Share fully + data = self.pre_share_model_transformed + m["params"] = data.numpy() + if self.model.accumulated_changes is not None: + self.model.accumulated_changes = torch.zeros_like( + self.model.accumulated_changes + ) + return m + + with torch.no_grad(): + topk, self.G_topk = self.apply_wavelet() + self.model.shared_parameters_counter[self.G_topk] += 1 + self.model.rewind_accumulation(self.G_topk) + if self.save_shared: + shared_params = dict() + shared_params["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v in self.model.state_dict().items(): + shapes[k] = list(v.shape) + shared_params["shapes"] = shapes + + shared_params[self.communication_round] = self.G_topk.tolist() + + with open( + os.path.join( + self.folder_path, + "{}_shared_params.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(shared_params, of) + + if not self.dict_ordered: + raise NotImplementedError + + m["alpha"] = self.alpha + + m["indices"] = self.G_topk.numpy().astype(np.int32) + + m["params"] = topk.numpy() + + m["send_partial"] = True + + assert len(m["indices"]) == len(m["params"]) + logging.info("Elements sending: {}".format(len(m["indices"]))) + + logging.info("Generated dictionary to send") + + logging.info("Converted dictionary to pickle") + + return m + + + def deserialized_model(self, m): + """ + Convert received dict to state_dict. + + Parameters + ---------- + m : dict + received dict + + Returns + ------- + state_dict + state_dict of received + + """ + ret = dict() + if "send_partial" not in m: + params = m["params"] + params_tensor = torch.tensor(params) + ret["params"] = params_tensor + return ret + + with torch.no_grad(): + if not self.dict_ordered: + raise NotImplementedError + + params_tensor = torch.tensor(m["params"]) + indices_tensor = torch.tensor(m["indices"], dtype=torch.long) + ret = dict() + ret["indices"] = indices_tensor + ret["params"] = params_tensor + ret["send_partial"] = True + return ret + + def _averaging(self): + """ + Averages the received model with the local model + + """ + with torch.no_grad(): + total = None + weight_total = 0 + wt_params = self.pre_share_model_transformed + if not self.metro_hastings: + weight_vector = torch.ones_like(wt_params) + datas = [] + batch = self._preprocessing_received_models() + for n, vals in batch.items(): + if len(vals) > 1: + data = None + degree = 0 + # this should no longer happen, unless we get two rw from the same originator + logging.info("averaging double messages for %i", n) + for val in vals: + degree_sub, iteration, data_sub = val + if data is None: + data = data_sub + degree = degree + else: + for key, weight_val in data_sub.items(): + data[key] += weight_val + degree = max(degree, degree_sub) + for key, weight_val in data.items(): + data[key] /= len(vals) + else: + degree, iteration, data = vals[0] + #degree, iteration, data = self.peer_deques[n].popleft() + logging.debug( + "Averaging model from neighbor {} of iteration {}".format( + n, iteration + ) + ) + #data = self.deserialized_model(data) + params = data["params"] + if "indices" in data: + indices = data["indices"] + if not self.metro_hastings: + weight_vector[indices] += 1 + topkwf = torch.zeros_like(wt_params) + topkwf[indices] = params + topkwf = topkwf.reshape(self.wt_shape) + datas.append(topkwf) + else: + # use local data to complement + topkwf = wt_params.clone().detach() + topkwf[indices] = 
params + topkwf = topkwf.reshape(self.wt_shape) + + else: + topkwf = params.reshape(self.wt_shape) + if not self.metro_hastings: + weight_vector += 1 + datas.append(topkwf) + + if self.metro_hastings: + weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings + weight_total += weight + if total is None: + total = weight * topkwf + else: + total += weight * topkwf + if not self.metro_hastings: + weight_vector = 1.0 / weight_vector + # speed up by exploiting sparsity + total = wt_params * weight_vector + for d in datas: + total += d * weight_vector + else: + # Metro-Hastings + total += (1 - weight_total) * wt_params + + avg_wf_params = pywt.array_to_coeffs( + total.numpy(), self.coeff_slices, output_format="wavedec" + ) + reverse_total = torch.from_numpy( + pywt.waverec(avg_wf_params, wavelet=self.wavelet) + ) + + start_index = 0 + std_dict = {} + for i, key in enumerate(self.model.state_dict()): + end_index = start_index + self.lens[i] + std_dict[key] = reverse_total[start_index:end_index].reshape( + self.shapes[i] + ) + start_index = end_index + + self.model.load_state_dict(std_dict) + to_cat = [] + for _, v in self.model.state_dict().items(): + vf = v.clone().detach().flatten() + to_cat.append(vf) + self.init_model = torch.cat(to_cat) + self._transformed = self.change_transformer(self.init_model) + + + def _pre_step(self): + """ + Called at the beginning of step. + + """ + logging.info("PartialModel _pre_step") + with torch.no_grad(): + tensors_to_cat = [ + v.data.flatten() for _, v in self.model.state_dict().items() + ] + self.pre_share_model = torch.cat(tensors_to_cat, dim=0) + # Would only need one of the transforms + self.pre_share_model_transformed = self.change_transformer( + self.pre_share_model + ) + + + def _post_step(self): + """ + Called at the end of step. + + """ + change = self.change_transformer(self.pre_share_model - self.init_model) # self.init_model is set in _averaging + if self.accumulation: + # Need to accumulate in _pre_step as the accumulation gets rewind during the step + self.model.accumulated_changes += change + change = self.model.accumulated_changes.clone().detach() + # stores change of the model due to training, change due to averaging is not accounted + self.model.model_change = change + + + def _send_rw(self): + def send(): + # will have to send the data twice to make the code simpler (for the beginning) + if self.alpha >= self.metadata_cap: + rw_data = { + "params": self.init_model.numpy(), + "rw": True, + "degree": self.number_of_neighbors, + "iteration": self.communication_round, + "visited": [self.uid], + "fuel": self.rw_length - 1, + } + else: + rw_data = { + "params": self._transformed[self.G_topk].numpy(), + "indices": self.G_topk.numpy().astype(np.int32), + "rw": True, + "degree": self.number_of_neighbors, + "iteration": self.communication_round, + "visited": [self.uid], + "fuel": self.rw_length - 1, + "send_partial": True, + } + logging.info("new rw message") + self.communication.send(None, rw_data) + + rw_chance = self.rw_chance + self.serialized_model() # dummy call to get self.G_topK + while rw_chance >= 1.0: + # TODO: make sure they are not sent to the same neighbour + send() + rw_chance -= 1 + rw_now = torch.rand(size=(1,), generator=self.random_generator).item() + if rw_now < rw_chance: + send() + + + + def save_vector(self, v, s): + """ + Saves the given vector to the file. 
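The non-Metropolis-Hastings branch of _averaging above amounts to a per-coordinate average in the wavelet domain: each received partial model is scattered into a dense vector, and every coordinate is divided by the number of models that actually contributed to it. A minimal sketch (degree handling, wavelet reshaping, and the Metropolis-Hastings branch omitted; names are illustrative):

    import torch

    def average_partials(local, received):
        # received: list of (indices, values) pairs from partially shared models.
        weight = torch.ones_like(local)   # the local model contributes everywhere
        total = local.clone()
        for idx, vals in received:
            dense = torch.zeros_like(local)
            dense[idx] = vals
            total += dense
            weight[idx] += 1              # one more contributor for these coordinates
        return total / weight             # per-coordinate average

    local = torch.randn(8)
    received = [(torch.tensor([0, 3, 5]), torch.randn(3)),
                (torch.tensor([3, 7]), torch.randn(2))]
    avg = average_partials(local, received)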
+ + Parameters + ---------- + v : torch.Tensor + The torch tensor to write to file + s : str + Path to folder to write to + + """ + output_dict = dict() + output_dict["order"] = list(self.model.state_dict().keys()) + shapes = dict() + for k, v1 in self.model.state_dict().items(): + shapes[k] = list(v1.shape) + output_dict["shapes"] = shapes + + output_dict["tensor"] = v.tolist() + + with open( + os.path.join( + s, + "{}.json".format(self.communication_round + 1), + ), + "w", + ) as of: + json.dump(output_dict, of) + + def save_change(self): + """ + Saves the model change for the current iteration + + """ + self.save_vector(self.model.model_change, self.model_change_path) -- GitLab
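save_vector and save_change above emit one JSON file per communication round with keys "order", "shapes", and "tensor" (the tensor being the wavelet-domain change). A small sketch of reading such a file back for offline analysis; the path is illustrative:

    import json
    import torch

    def load_saved_vector(path):
        with open(path) as f:
            d = json.load(f)
        # "order"/"shapes" record the state_dict layout at save time;
        # the vector itself lives in the wavelet domain.
        return torch.tensor(d["tensor"]), d["order"], d["shapes"]

    vec, order, shapes = load_saved_vector("model_change/0/5.json")
    print(vec.abs().mean().item(), len(order))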