diff --git a/eval/run_xtimes_reddit_rws.sh b/eval/run_xtimes_reddit_rws.sh
index 37265011116a3be860309a402c5079829d29b600..704055f3e88d6a48e6ee17fdc58af08292917f32 100755
--- a/eval/run_xtimes_reddit_rws.sh
+++ b/eval/run_xtimes_reddit_rws.sh
@@ -54,7 +54,7 @@ export PYTHONFAULTHANDLER=1
 # Base configs for which the gird search is done
 # tests=("step_configs/config_reddit_sharing_topKdynamicGraph.ini")
 # tests=("step_configs/config_reddit_sharing_topKsharingasyncrw.ini" "step_configs/config_reddit_sharing_topKdpsgdrwasync.ini" "step_configs/config_reddit_sharing_topKdpsgdrw.ini")
-tests=("step_configs/config_reddit_sharing_dpsgdrwasync0.ini")
+tests=("step_configs/config_reddit_sharing_dynamicGraph.ini")# ("step_configs/config_reddit_sharing_dpsgdrwasync0.ini")
 # tests=("step_configs/config_reddit_sharing_dpsgdrw.ini" "step_configs/config_reddit_sharing_dpsgdrwasync.ini" "step_configs/config_reddit_sharing_sharingasyncrw.ini" "step_configs/config_reddit_sharing_sharingrw.ini")
 # Learning rates
 lr="1"
diff --git a/eval/step_configs/config_reddit_sharing_dynamicGraph.ini b/eval/step_configs/config_reddit_sharing_dynamicGraph.ini
index 44849cc1440edc15b747a2289f831405b3b790e8..3caeabb95177deb802b4a229eb729f7ae3c81bfd 100644
--- a/eval/step_configs/config_reddit_sharing_dynamicGraph.ini
+++ b/eval/step_configs/config_reddit_sharing_dynamicGraph.ini
@@ -32,4 +32,4 @@ sampler = equi
 [SHARING]
 sharing_package = decentralizepy.sharing.SharingDynamicGraph
 sharing_class = SharingDynamicGraph
-avg = True
\ No newline at end of file
+avg = False
\ No newline at end of file
diff --git a/src/decentralizepy/communication/TCPRandomWalk.py b/src/decentralizepy/communication/TCPRandomWalk.py
index 6dee4dbbe18173491b9f09e452224643cadf87dd..10dcdfb76b7f7cda55d40332189eda23776a9730 100644
--- a/src/decentralizepy/communication/TCPRandomWalk.py
+++ b/src/decentralizepy/communication/TCPRandomWalk.py
@@ -129,17 +129,23 @@ class TCPRandomWalkBase(Communication):
         logging.debug("in encrypt")
         if self.compress:
             logging.debug("in encrypt: compress")
-            if "indices" in data:
-                data["indices"] = self.compressor.compress(data["indices"])
+            if type(data) == dict:
+                if "indices" in data:
+                    data["indices"] = self.compressor.compress(data["indices"])
 
-            assert "params" in data
-            data["params"] = self.compressor.compress_float(data["params"])
-            data_len = len(pickle.dumps(data["params"]))
-            output = pickle.dumps(data)
+                if "params" in data:
+                    data["params"] = self.compressor.compress_float(data["params"])
+                    data_len = len(pickle.dumps(data["params"]))
+                else:
+                    data_len = 0
+                output = pickle.dumps(data)
+
+                # the compressed meta data gets only a few bytes smaller after pickling
+                self.total_meta.value += (len(output) - data_len)
+                self.total_data.value += data_len
+            else:
+                output = pickle.dumps(data)
 
-            # the compressed meta data gets only a few bytes smaller after pickling
-            self.total_meta.value += (len(output) - data_len)
-            self.total_data.value += data_len
         else:
             logging.debug("in encrypt: else")
             output = pickle.dumps(data)
@@ -175,10 +181,15 @@ class TCPRandomWalkBase(Communication):
         if self.compress:
             logging.debug("in decrypt:comp")
             data = pickle.loads(data)
-            if "indices" in data:
-                data["indices"] = self.compressor.decompress(data["indices"])
-            if "params" in data:
-                data["params"] = self.compressor.decompress_float(data["params"])
+            if type(data) == dict:
+                if "indices" in data:
+                    data["indices"] = self.compressor.decompress(data["indices"])
+                if "params" in data:
+                    try:
+                        data["params"] = self.compressor.decompress_float(data["params"])
+                    except:
+                        print(f"faled for {data}")
+
         else:
             logging.debug("in decrypt:else")
             data = pickle.loads(data)
diff --git a/src/decentralizepy/sharing/JwinsDynamicGraph.py b/src/decentralizepy/sharing/JwinsDynamicGraph.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1d6292d487e0cbea908297657e466cbfb627b9d
--- /dev/null
+++ b/src/decentralizepy/sharing/JwinsDynamicGraph.py
@@ -0,0 +1,511 @@
+import json
+import logging
+import os
+from pathlib import Path
+
+import numpy as np
+import pywt
+import torch
+from decentralizepy.sharing.SharingDynamicGraph import SharingDynamicGraph
+from decentralizepy.utils import conditional_value, identity
+
+def change_transformer_wavelet(x, wavelet, level):
+    """
+    Transforms the model changes into wavelet frequency domain
+
+    Parameters
+    ----------
+    x : torch.Tensor
+        Model change in the space domain
+    wavelet : str
+        name of the wavelet to be used in gradient compression
+    level: int
+        name of the wavelet to be used in gradient compression
+
+    Returns
+    -------
+    x : torch.Tensor
+        Representation of the change int the wavelet domain
+    """
+    coeff = pywt.wavedec(x, wavelet, level=level)
+    data, coeff_slices = pywt.coeffs_to_array(coeff)
+    return torch.from_numpy(data.ravel())
+
+class JwinsDynamicGraph(SharingDynamicGraph):
+    """
+    This class implements the vanilla version of partial model sharing.
+
+    """
+
+    def __init__(
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        log_dir,
+        alpha=1.0,
+        dict_ordered=True,
+        save_shared=False,
+        metadata_cap=1.0,
+        wavelet="haar",
+        level=4,
+        change_based_selection=True,
+        save_accumulated="",
+        accumulation=False,
+        accumulate_averaging_changes=False,
+        rw_chance=0.1,
+        rw_length=4,
+        neighbor_bound=(3, 5),
+        change_topo_interval=10,
+        avg = False,
+        lower_bound=0.1,
+        metro_hastings=True,
+    ):
+        """
+        Constructor
+
+        Parameters
+        ----------
+        rank : int
+            Local rank
+        machine_id : int
+            Global machine id
+        communication : decentralizepy.communication.Communication
+            Communication module used to send and receive messages
+        mapping : decentralizepy.mappings.Mapping
+            Mapping (rank, machine_id) -> uid
+        graph : decentralizepy.graphs.Graph
+            Graph reprensenting neighbors
+        model : decentralizepy.models.Model
+            Model to train
+        dataset : decentralizepy.datasets.Dataset
+            Dataset for sharing data. Not implemented yet! TODO
+        log_dir : str
+            Location to write shared_params (only writing for 2 procs per machine)
+        alpha : float
+            Percentage of model to share
+        dict_ordered : bool
+            Specifies if the python dict maintains the order of insertion
+        save_shared : bool
+            Specifies if the indices of shared parameters should be logged
+        metadata_cap : float
+            Share full model when self.alpha > metadata_cap
+        wavelet: str
+            name of the wavelet to be used in gradient compression
+        level: int
+            name of the wavelet to be used in gradient compression
+        change_based_selection : bool
+            use frequency change to select topk frequencies
+        save_accumulated : bool
+            True if accumulated weight change in the wavelet domain should be written to file. In case of accumulation
+            the accumulated change is stored.
+        accumulation : bool
+            True if the the indices to share should be selected based on accumulated frequency change
+        save_accumulated : bool
+            True if accumulated weight change should be written to file. In case of accumulation the accumulated change
+            is stored. If a change_transformer is used then the transformed change is stored.
+        change_transformer : (x: Tensor) -> Tensor
+            A function that transforms the model change into other domains. Default: identity function
+        accumulate_averaging_changes: bool
+            True if the accumulation should account the model change due to averaging
+        """
+        self.wavelet = wavelet
+        self.level = level
+
+        super().__init__(
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir, rw_chance, rw_length, neighbor_bound, change_topo_interval, avg
+        )
+        self.alpha = alpha
+        self.dict_ordered = dict_ordered
+        self.save_shared = save_shared
+        self.metadata_cap = metadata_cap
+        self.accumulation = accumulation
+        self.save_accumulated = conditional_value(save_accumulated, "", False)
+        self.change_transformer = lambda x: change_transformer_wavelet(x, wavelet, level)
+        self.accumulate_averaging_changes = accumulate_averaging_changes
+
+        # getting the initial model
+        self.shapes = []
+        self.lens = []
+        with torch.no_grad():
+            tensors_to_cat = []
+            for _, v in self.model.state_dict().items():
+                self.shapes.append(v.shape)
+                t = v.flatten()
+                self.lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+            self.init_model = torch.cat(tensors_to_cat, dim=0)
+            if self.accumulation:
+                self.model.accumulated_changes = torch.zeros_like(
+                    self.change_transformer(self.init_model)
+                )
+                self.prev = self.init_model
+        self.number_of_params = self.init_model.shape[0]
+        if self.save_accumulated:
+            self.model_change_path = os.path.join(
+                self.log_dir, "model_change/{}".format(self.rank)
+            )
+            Path(self.model_change_path).mkdir(parents=True, exist_ok=True)
+
+            self.model_val_path = os.path.join(
+                self.log_dir, "model_val/{}".format(self.rank)
+            )
+            Path(self.model_val_path).mkdir(parents=True, exist_ok=True)
+
+        # Only save for 2 procs: Save space
+        if self.save_shared and not (rank == 0 or rank == 1):
+            self.save_shared = False
+
+        if self.save_shared:
+            self.folder_path = os.path.join(
+                self.log_dir, "shared_params/{}".format(self.rank)
+            )
+            Path(self.folder_path).mkdir(parents=True, exist_ok=True)
+
+        self.model.shared_parameters_counter = torch.zeros(
+            self.change_transformer(self.init_model).shape[0], dtype=torch.int32
+        )
+
+        self.change_based_selection = change_based_selection
+
+        # Do a dummy transform to get the shape and coefficents slices
+        coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
+        data, coeff_slices = pywt.coeffs_to_array(coeff)
+        self.wt_shape = data.shape
+        self.coeff_slices = coeff_slices
+
+        self.lower_bound = lower_bound
+        self.metro_hastings = metro_hastings
+        if self.lower_bound > 0:
+            self.start_lower_bounding_at = 1 / self.lower_bound
+
+    def apply_wavelet(self):
+        """
+        Does wavelet transformation of the model parameters and selects topK (alpha) of them in the frequency domain
+        based on the undergone change during the current training step
+
+        Returns
+        -------
+        tuple
+            (a,b). a: selected wavelet coefficients, b: Their indices.
+
+        """
+
+        data = self.pre_share_model_transformed
+        if self.change_based_selection:
+            diff = self.model.model_change
+            _, index = torch.topk(
+                diff.abs(),
+                round(self.alpha * len(diff)),
+                dim=0,
+                sorted=False,
+            )
+        else:
+            _, index = torch.topk(
+                data.abs(),
+                round(self.alpha * len(data)),
+                dim=0,
+                sorted=False,
+            )
+        ind, _ = torch.sort(index)
+
+        if self.lower_bound == 0.0:
+            return data[ind].clone().detach(), ind
+
+        if self.communication_round > self.start_lower_bounding_at:
+            # because the superclass increases it where it is inconvenient for this subclass
+            currently_shared = self.model.shared_parameters_counter.clone().detach()
+            currently_shared[ind] += 1
+            ind_small = (
+                currently_shared < self.communication_round * self.lower_bound
+            ).nonzero(as_tuple=True)[0]
+            ind_small_unique = np.setdiff1d(
+                ind_small.numpy(), ind.numpy(), assume_unique=True
+            )
+            take_max = round(self.lower_bound * self.alpha * data.shape[0])
+            logging.info(
+                "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max
+            )
+            if take_max > ind_small_unique.shape[0]:
+                take_max = ind_small_unique.shape[0]
+            to_take = torch.rand(ind_small_unique.shape[0])
+            _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False)
+            ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take]
+            logging.info("lower bounding: %i %i", len(ind), len(ind_bound))
+            # val = torch.concat(val, G_topk[ind_bound]) # not really needed, as thes are abs values and not further used
+            ind = torch.cat([ind, ind_bound])
+
+        index, _ = torch.sort(ind)
+        return data[index].clone().detach(), index
+
+    def serialized_model(self):
+        """
+        Convert model to json dict. self.alpha specifies the fraction of model to send.
+
+        Returns
+        -------
+        dict
+            Model converted to json dict
+
+        """
+        m = dict()
+        if self.alpha >= self.metadata_cap:  # Share fully
+            data = self.pre_share_model_transformed
+            m["params"] = data.numpy()
+            if self.model.accumulated_changes is not None:
+                self.model.accumulated_changes = torch.zeros_like(
+                    self.model.accumulated_changes
+                )
+            return m
+
+        with torch.no_grad():
+            topk, self.G_topk = self.apply_wavelet()
+            self.model.shared_parameters_counter[self.G_topk] += 1
+            self.model.rewind_accumulation(self.G_topk)
+            if self.save_shared:
+                shared_params = dict()
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+                shared_params[self.communication_round] = self.G_topk.tolist()
+
+                with open(
+                    os.path.join(
+                        self.folder_path,
+                        "{}_shared_params.json".format(self.communication_round + 1),
+                    ),
+                    "w",
+                ) as of:
+                    json.dump(shared_params, of)
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["alpha"] = self.alpha
+
+            m["indices"] = self.G_topk.numpy().astype(np.int32)
+
+            m["params"] = topk.numpy()
+
+            m["send_partial"] = True
+
+            assert len(m["indices"]) == len(m["params"])
+            logging.info("Elements sending: {}".format(len(m["indices"])))
+
+            logging.info("Generated dictionary to send")
+
+            logging.info("Converted dictionary to pickle")
+
+            return m
+
+
+    def deserialized_model(self, m):
+        """
+        Convert received dict to state_dict.
+
+        Parameters
+        ----------
+        m : dict
+            received dict
+
+        Returns
+        -------
+        state_dict
+            state_dict of received
+
+        """
+        ret = dict()
+        if "send_partial" not in m:
+            params = m["params"]
+            params_tensor = torch.tensor(params)
+            ret["params"] = params_tensor
+            return ret
+
+        with torch.no_grad():
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            params_tensor = torch.tensor(m["params"])
+            indices_tensor = torch.tensor(m["indices"], dtype=torch.long)
+            ret = dict()
+            ret["indices"] = indices_tensor
+            ret["params"] = params_tensor
+            ret["send_partial"] = True
+        return ret
+
+    def _averaging(self):
+        """
+        Averages the received model with the local model
+
+        """
+        with torch.no_grad():
+            total = None
+            weight_total = 0
+            wt_params = self.pre_share_model_transformed
+            if not self.metro_hastings:
+                weight_vector = torch.ones_like(wt_params)
+                datas = []
+            batch = self._preprocessing_received_models()
+            for n, vals in batch.items():
+                if len(vals) > 1:
+                    # this should never happen
+                    logging.info("GOT double message in dynamic graph!")
+                else:
+                    degree, iteration, data = vals[0]
+                #degree, iteration, data = self.peer_deques[n].popleft()
+                logging.debug(
+                    "Averaging model from neighbor {} of iteration {}".format(
+                        n, iteration
+                    )
+                )
+                #data = self.deserialized_model(data)
+                params = data["params"]
+                if "indices" in data:
+                    indices = data["indices"]
+                    if not self.metro_hastings:
+                        weight_vector[indices] += 1
+                        topkwf = torch.zeros_like(wt_params)
+                        topkwf[indices] = params
+                        topkwf = topkwf.reshape(self.wt_shape)
+                        datas.append(topkwf)
+                    else:
+                        # use local data to complement
+                        topkwf = wt_params.clone().detach()
+                        topkwf[indices] = params
+                        topkwf = topkwf.reshape(self.wt_shape)
+
+                else:
+                    topkwf = params.reshape(self.wt_shape)
+                    if not self.metro_hastings:
+                        weight_vector += 1
+                        datas.append(topkwf)
+
+                if self.metro_hastings:
+                    weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
+                    weight_total += weight
+                    if total is None:
+                        total = weight * topkwf
+                    else:
+                        total += weight * topkwf
+            if not self.metro_hastings:
+                weight_vector = 1.0 / weight_vector
+                # speed up by exploiting sparsity
+                total = wt_params * weight_vector
+                for d in datas:
+                    total += d * weight_vector
+            else:
+                # Metro-Hastings
+                total += (1 - weight_total) * wt_params
+
+            avg_wf_params = pywt.array_to_coeffs(
+                total.numpy(), self.coeff_slices, output_format="wavedec"
+            )
+            reverse_total = torch.from_numpy(
+                pywt.waverec(avg_wf_params, wavelet=self.wavelet)
+            )
+
+            start_index = 0
+            std_dict = {}
+            for i, key in enumerate(self.model.state_dict()):
+                end_index = start_index + self.lens[i]
+                std_dict[key] = reverse_total[start_index:end_index].reshape(
+                    self.shapes[i]
+                )
+                start_index = end_index
+
+        self.model.load_state_dict(std_dict)
+
+
+    def _pre_step(self):
+        """
+        Called at the beginning of step.
+
+        """
+        logging.info("PartialModel _pre_step")
+        with torch.no_grad():
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            self.pre_share_model = torch.cat(tensors_to_cat, dim=0)
+            # Would only need one of the transforms
+            self.pre_share_model_transformed = self.change_transformer(
+                self.pre_share_model
+            )
+            change = self.change_transformer(self.pre_share_model - self.init_model)
+            if self.accumulation:
+                if not self.accumulate_averaging_changes:
+                    # Need to accumulate in _pre_step as the accumulation gets rewind during the step
+                    self.model.accumulated_changes += change
+                    change = self.model.accumulated_changes.clone().detach()
+                else:
+                    # For the legacy implementation, we will only rewind currently accumulated values
+                    # and add the model change due to averaging in the end
+                    change += self.model.accumulated_changes
+            # stores change of the model due to training, change due to averaging is not accounted
+            self.model.model_change = change
+
+    def _post_step(self):
+        """
+        Called at the end of step.
+
+        """
+        logging.info("PartialModel _post_step")
+        with torch.no_grad():
+            tensors_to_cat = [
+                v.data.flatten() for _, v in self.model.state_dict().items()
+            ]
+            post_share_model = torch.cat(tensors_to_cat, dim=0)
+            self.init_model = post_share_model
+            if self.accumulation:
+                if self.accumulate_averaging_changes:
+                    self.model.accumulated_changes += self.change_transformer(
+                        self.init_model - self.prev
+                    )
+                self.prev = self.init_model
+            self.model.model_change = None
+        if self.save_accumulated:
+            self.save_change()
+
+    def save_vector(self, v, s):
+        """
+        Saves the given vector to the file.
+
+        Parameters
+        ----------
+        v : torch.tensor
+            The torch tensor to write to file
+        s : str
+            Path to folder to write to
+
+        """
+        output_dict = dict()
+        output_dict["order"] = list(self.model.state_dict().keys())
+        shapes = dict()
+        for k, v1 in self.model.state_dict().items():
+            shapes[k] = list(v1.shape)
+        output_dict["shapes"] = shapes
+
+        output_dict["tensor"] = v.tolist()
+
+        with open(
+            os.path.join(
+                s,
+                "{}.json".format(self.communication_round + 1),
+            ),
+            "w",
+        ) as of:
+            json.dump(output_dict, of)
+
+    def save_change(self):
+        """
+        Saves the change and the gradient values for every iteration
+
+        """
+        self.save_vector(self.model.model_change, self.model_change_path)
diff --git a/src/decentralizepy/sharing/SharingDynamicGraph.py b/src/decentralizepy/sharing/SharingDynamicGraph.py
index 25647251d76017506e8d4c50c423510fdbfc17a8..0c744dde3997fedeee4f8e60ffa3424f594a0637 100644
--- a/src/decentralizepy/sharing/SharingDynamicGraph.py
+++ b/src/decentralizepy/sharing/SharingDynamicGraph.py
@@ -31,7 +31,7 @@ class SharingDynamicGraph(DPSGDRW):
         rw_length=6,
         neighbor_bound=(3, 5),
         change_topo_interval=10,
-        avg = True
+        avg = False
     ):
         """
         Constructor
@@ -70,7 +70,7 @@ class SharingDynamicGraph(DPSGDRW):
             rw_chance,
             rw_length,
         )
-        self.avg = True
+        self.avg = avg
         self.neighbor_bound_lower = neighbor_bound[0]
         self.neighbor_bound_higher = neighbor_bound[1]
         self.change_topo_interval = change_topo_interval