From 4c73b35afe6b5dfb3eab13d70570ee3c5b826338 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Wed, 15 Jun 2022 17:55:26 +0200
Subject: [PATCH] Fix JWINS random walk

---
 src/decentralizepy/node/Node.py               |  1 +
 src/decentralizepy/sharing/JwinsDPSGDAsync.py | 22 ++++---------------
 2 files changed, 5 insertions(+), 18 deletions(-)

diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 8536b5c..033c1e5 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -519,6 +519,7 @@ class Node:
             ) as of:
                 json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
         logging.info("disconnect neighbors")
+        print("Node: disconnect")
         self.communication.disconnect_neighbors()
         logging.info("Storing final weight")
         # self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
diff --git a/src/decentralizepy/sharing/JwinsDPSGDAsync.py b/src/decentralizepy/sharing/JwinsDPSGDAsync.py
index b0e0552..d40ceea 100644
--- a/src/decentralizepy/sharing/JwinsDPSGDAsync.py
+++ b/src/decentralizepy/sharing/JwinsDPSGDAsync.py
@@ -353,25 +353,11 @@ class JwinsDPSGDAsync(DPSGDRWAsync):
                 weight_vector = torch.ones_like(wt_params)
                 datas = []
             batch = self._preprocessing_received_models()
+            new_batch = []
             for n, vals in batch.items():
-                if len(vals) > 1:
-                    data = None
-                    degree = 0
-                    # this should no longer happen, unless we get two rw from the same originator
-                    logging.info("averaging double messages for %i", n)
-                    for val in vals:
-                        degree_sub, iteration, data_sub = val
-                        if data is None:
-                            data = data_sub
-                            degree = degree
-                        else:
-                            for key, weight_val in data_sub.items():
-                                data[key] += weight_val
-                            degree = max(degree, degree_sub)
-                    for key, weight_val in data.items():
-                        data[key] /= len(vals)
-                else:
-                    degree, iteration, data = vals[0]
+                new_batch.extend(vals)
+            for vals in new_batch:
+                degree, iteration, data = vals
                 #degree, iteration, data = self.peer_deques[n].popleft()
                 logging.debug(
                     "Averaging model from neighbor {} of iteration {}".format(
-- 
GitLab