From 4d5daa73e6cdda3ebcaf1d487b802ba08558a650 Mon Sep 17 00:00:00 2001
From: Jeffrey Wigger <jeffrey.wigger@epfl.ch>
Date: Sun, 12 Jun 2022 05:00:30 +0200
Subject: [PATCH] cifar: sweep rw_chance for DPSGDRWAsync and extend training schedule

---
 eval/run_xtimes_cifar.sh                      |  8 ++--
 .../config_cifar_dpsgdWithRWAsync1.ini        | 37 +++++++++++++++++++
 .../config_cifar_dpsgdWithRWAsync2.ini        | 37 +++++++++++++++++++
 ...ini => config_cifar_dpsgdWithRWAsync4.ini} |  3 +-
 src/decentralizepy/node/Node.py               |  2 +
 5 files changed, 82 insertions(+), 5 deletions(-)
 create mode 100644 eval/step_configs/config_cifar_dpsgdWithRWAsync1.ini
 create mode 100644 eval/step_configs/config_cifar_dpsgdWithRWAsync2.ini
 rename eval/step_configs/{config_cifar_dpsgdWithRWAsync.ini => config_cifar_dpsgdWithRWAsync4.ini} (95%)

diff --git a/eval/run_xtimes_cifar.sh b/eval/run_xtimes_cifar.sh
index 0ee7962..54aa87a 100755
--- a/eval/run_xtimes_cifar.sh
+++ b/eval/run_xtimes_cifar.sh
@@ -42,7 +42,7 @@ graph=96_regular.edges
 config_file=~/tmp/config.ini
 procs_per_machine=16
 machines=6
-global_epochs=800
+global_epochs=1000
 eval_file=testing.py
 log_level=INFO
 
@@ -52,7 +52,7 @@ m=`cat $ip_machines | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print
 export PYTHONFAULTHANDLER=1
 
 # Base configs for which the gird search is done
-tests=("step_configs/config_cifar_sharing.ini") #"step_configs/config_cifar_partialmodel.ini" "step_configs/config_cifar_topkacc.ini" "step_configs/config_cifar_subsampling.ini" "step_configs/config_cifar_wavelet.ini")
+tests=("step_configs/config_cifar_sharing.ini" "step_configs/config_cifar_dpsgdWithRWAsync1.ini" "step_configs/config_cifar_dpsgdWithRWAsync2.ini" "step_configs/config_cifar_dpsgdWithRWAsync4.ini") #"step_configs/config_cifar_partialmodel.ini" "step_configs/config_cifar_topkacc.ini" "step_configs/config_cifar_subsampling.ini" "step_configs/config_cifar_wavelet.ini")
 # Learning rates
 lr="0.01"
 # Batch size
@@ -68,7 +68,7 @@ echo samples per user: $samples_per_user
 
 # random_seeds for which to rerun the experiments
 # random_seeds=("90" "91" "92" "93" "94")
-random_seeds=("97")
+random_seeds=("90" "91" "92")
 echo batchsize: $batchsize
 echo communication rounds per global epoch: $comm_rounds_per_global_epoch
 # calculating how many batches there are in a global epoch for each user/proc
@@ -107,7 +107,7 @@ do
     $python_bin/crudini --set $config_file DATASET random_seed $seed
     $env_python $eval_file -ro 0 -tea $test_after -ld $log_dir -wsd $weight_store_dir -mid $m -ps $procs_per_machine -ms $machines -is $new_iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
     echo $i is done
-    sleep 200
+    sleep 300
     echo end of sleep
     done
 done
diff --git a/eval/step_configs/config_cifar_dpsgdWithRWAsync1.ini b/eval/step_configs/config_cifar_dpsgdWithRWAsync1.ini
new file mode 100644
index 0000000..e83e367
--- /dev/null
+++ b/eval/step_configs/config_cifar_dpsgdWithRWAsync1.ini
@@ -0,0 +1,37 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.CIFAR10
+dataset_class = CIFAR10
+model_class = LeNet
+train_dir = /mnt/nfs/shared/CIFAR
+test_dir = /mnt/nfs/shared/CIFAR
+; python list of fractions below
+sizes =
+random_seed = 99
+partition_niid = True
+shards = 1
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = SGD
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 65
+full_epochs = False
+batch_size = 8
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCPRandomWalk
+comm_class = TCPRandomWalk
+addresses_filepath = ip_addr_6Machines.json
+sampler = equi_check_history
+
+[SHARING]
+sharing_package = decentralizepy.sharing.DPSGDRWAsync
+sharing_class = DPSGDRWAsync
+rw_chance=0.25
diff --git a/eval/step_configs/config_cifar_dpsgdWithRWAsync2.ini b/eval/step_configs/config_cifar_dpsgdWithRWAsync2.ini
new file mode 100644
index 0000000..b21b323
--- /dev/null
+++ b/eval/step_configs/config_cifar_dpsgdWithRWAsync2.ini
@@ -0,0 +1,37 @@
+[DATASET]
+dataset_package = decentralizepy.datasets.CIFAR10
+dataset_class = CIFAR10
+model_class = LeNet
+train_dir = /mnt/nfs/shared/CIFAR
+test_dir = /mnt/nfs/shared/CIFAR
+; python list of fractions below
+sizes =
+random_seed = 99
+partition_niid = True
+shards = 1
+
+[OPTIMIZER_PARAMS]
+optimizer_package = torch.optim
+optimizer_class = SGD
+lr = 0.001
+
+[TRAIN_PARAMS]
+training_package = decentralizepy.training.Training
+training_class = Training
+rounds = 65
+full_epochs = False
+batch_size = 8
+shuffle = True
+loss_package = torch.nn
+loss_class = CrossEntropyLoss
+
+[COMMUNICATION]
+comm_package = decentralizepy.communication.TCPRandomWalk
+comm_class = TCPRandomWalk
+addresses_filepath = ip_addr_6Machines.json
+sampler = equi_check_history
+
+[SHARING]
+sharing_package = decentralizepy.sharing.DPSGDRWAsync
+sharing_class = DPSGDRWAsync
+rw_chance=0.5
diff --git a/eval/step_configs/config_cifar_dpsgdWithRWAsync.ini b/eval/step_configs/config_cifar_dpsgdWithRWAsync4.ini
similarity index 95%
rename from eval/step_configs/config_cifar_dpsgdWithRWAsync.ini
rename to eval/step_configs/config_cifar_dpsgdWithRWAsync4.ini
index e11f86b..967ebfd 100644
--- a/eval/step_configs/config_cifar_dpsgdWithRWAsync.ini
+++ b/eval/step_configs/config_cifar_dpsgdWithRWAsync4.ini
@@ -29,8 +29,9 @@ loss_class = CrossEntropyLoss
 comm_package = decentralizepy.communication.TCPRandomWalk
 comm_class = TCPRandomWalk
 addresses_filepath = ip_addr_6Machines.json
-sampler = equi
+sampler = equi_check_history
 
 [SHARING]
 sharing_package = decentralizepy.sharing.DPSGDRWAsync
 sharing_class = DPSGDRWAsync
+rw_chance=1
diff --git a/src/decentralizepy/node/Node.py b/src/decentralizepy/node/Node.py
index 98e6664..8536b5c 100644
--- a/src/decentralizepy/node/Node.py
+++ b/src/decentralizepy/node/Node.py
@@ -499,6 +499,8 @@ class Node:
                     change = 5
                 if global_epoch == 119:
                     change = 10
+                if global_epoch == 499:
+                    change = 20
 
                 global_epoch += change
 
-- 
GitLab