Commit 76f1fbb0 authored by Jeffrey Wigger's avatar Jeffrey Wigger

integrating new sharing methods

parent 98e938bc
Showing with 2280 additions and 16 deletions
96
0 24
0 1
0 26
0 95
1 2
1 0
1 82
1 83
2 33
2 90
2 3
2 1
3 2
3 4
3 14
3 79
4 3
4 12
4 5
4 86
5 64
5 42
5 4
5 6
6 9
6 5
6 62
6 7
7 24
7 8
7 45
7 6
8 81
8 17
8 9
8 7
9 8
9 10
9 53
9 6
10 9
10 11
10 29
10 31
11 80
11 10
11 36
11 12
12 11
12 4
12 13
12 70
13 12
13 53
13 30
13 14
14 3
14 15
14 13
14 47
15 16
15 26
15 14
16 41
16 17
16 15
17 8
17 16
17 18
17 83
18 17
18 19
18 95
18 63
19 82
19 18
19 20
19 22
20 19
20 59
20 21
20 22
21 72
21 58
21 20
21 22
22 19
22 20
22 21
22 23
23 24
23 65
23 85
23 22
24 0
24 25
24 23
24 7
25 32
25 24
25 26
25 38
26 0
26 25
26 27
26 15
27 32
27 26
27 28
27 63
28 27
28 92
28 29
28 39
29 10
29 52
29 28
29 30
30 66
30 29
30 13
30 31
31 32
31 10
31 36
31 30
32 25
32 27
32 31
32 33
33 32
33 2
33 84
33 34
34 33
34 50
34 35
34 93
35 57
35 34
35 43
35 36
36 35
36 11
36 37
36 31
37 88
37 36
37 38
37 79
38 25
38 37
38 39
38 49
39 40
39 28
39 77
39 38
40 41
40 91
40 39
40 87
41 16
41 40
41 42
41 51
42 41
42 43
42 5
43 42
43 35
43 44
44 72
44 43
44 75
44 45
45 67
45 44
45 46
45 7
46 76
46 45
46 54
46 47
47 48
47 65
47 14
47 46
48 56
48 49
48 61
48 47
49 48
49 50
49 38
49 71
50 49
50 34
50 51
50 93
51 41
51 50
51 52
51 95
52 51
52 74
52 53
52 29
53 9
53 52
53 13
53 54
54 75
54 53
54 46
54 55
55 56
55 69
55 85
55 54
56 48
56 57
56 69
56 55
57 56
57 89
57 58
57 35
58 57
58 59
58 21
58 86
59 73
59 58
59 20
59 60
60 62
60 59
60 61
60 78
61 48
61 62
61 60
61 94
62 60
62 61
62 6
62 63
63 64
63 18
63 27
63 62
64 65
64 84
64 5
64 63
65 64
65 66
65 23
65 47
66 65
66 89
66 67
66 30
67 80
67 66
67 68
67 45
68 67
68 92
68 69
68 94
69 56
69 68
69 70
69 55
70 90
70 12
70 69
70 71
71 72
71 49
71 70
71 87
72 73
72 44
72 21
72 71
73 72
73 91
73 59
73 74
74 73
74 75
74 52
74 76
75 74
75 44
75 54
75 76
76 74
76 75
76 77
76 46
77 81
77 76
77 78
77 39
78 88
78 60
78 77
78 79
79 80
79 3
79 37
79 78
80 81
80 67
80 11
80 79
81 8
81 82
81 80
81 77
82 81
82 1
82 83
82 19
83 1
83 82
83 84
83 17
84 64
84 33
84 83
84 85
85 84
85 55
85 86
85 23
86 58
86 4
86 85
86 87
87 40
87 88
87 86
87 71
88 89
88 37
88 78
88 87
89 88
89 57
89 66
89 90
90 89
90 2
90 91
90 70
91 40
91 73
91 90
91 92
92 93
92 91
92 68
92 28
93 50
93 34
93 94
93 92
94 93
94 68
94 61
94 95
95 0
95 18
95 51
95 94
......@@ -61,14 +61,20 @@ def plot_results(path):
plt.figure(1)
means, stdevs, mins, maxs = get_stats([x["train_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Training Loss", folder, "upper right")
with open(os.path.join(path, "train_loss_" + folder + ".json"), "w") as f:
json.dump({"mean": means, "std": stdevs}, f)
# Plot Testing loss
plt.figure(2)
means, stdevs, mins, maxs = get_stats([x["test_loss"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Loss", folder, "upper right")
with open(os.path.join(path, "test_loss_" + folder + ".json"), "w") as f:
json.dump({"mean": means, "std": stdevs}, f)
# Plot Testing Accuracy
plt.figure(3)
means, stdevs, mins, maxs = get_stats([x["test_acc"] for x in results])
plot(means, stdevs, mins, maxs, "Testing Accuracy", folder, "lower right")
with open(os.path.join(path, "test_acc_" + folder + ".json"), "w") as f:
json.dump({"mean": means, "std": stdevs}, f)
plt.figure(6)
means, stdevs, mins, maxs = get_stats([x["grad_std"] for x in results])
plot(
......
......@@ -4,29 +4,21 @@ decpy_path=~/Gitlab/decentralizepy/eval
cd $decpy_path
env_python=~/miniconda3/envs/decpy/bin/python3
graph=96_nodes_random1.edges
graph=96_regular.edges
original_config=epoch_configs/config_celeba.ini
config_file=/tmp/config.ini
procs_per_machine=16
machines=6
iterations=76
test_after=2
iterations=200
test_after=10
eval_file=testing.py
log_level=INFO
log_dir_base=/mnt/nfs/some_user/logs/test
m=`cat $(grep addresses_filepath $original_config | awk '{print $3}') | grep $(/sbin/ifconfig ens785 | grep 'inet ' | awk '{print $2}') | cut -d'"' -f2`
cp $original_config $config_file
echo "alpha = 0.75" >> $config_file
$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
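# The pattern above repeats for each alpha: copy the base config, append the
# alpha override, then launch the evaluation. Flag meanings inferred from the
# variable names (an assumption, not confirmed by this commit): -mid machine
# id, -ps procs per machine, -ms machines, -is iterations, -gf graph file,
# -ta test interval, -cf config file, -ll log level, -ld log directory.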
cp $original_config $config_file
echo "alpha = 0.50" >> $config_file
$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
log_dir=$log_dir_base$m
cp $original_config $config_file
echo "alpha = 0.10" >> $config_file
$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
config_file=epoch_configs/config_celeba_100.ini
$env_python $eval_file -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $original_config -ll $log_level
# echo "alpha = 0.10" >> $config_file
$env_python $eval_file -ld $log_dir -mid $m -ps $procs_per_machine -ms $machines -is $iterations -gf $graph -ta $test_after -cf $config_file -ll $log_level
\ No newline at end of file
[DATASET]
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
; python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
# There are 734463 femnist samples
[TRAIN_PARAMS]
training_package = decentralizepy.training.FrequencyAccumulator
training_class = FrequencyAccumulator
rounds = 47
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
accumulation = True
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.FFT
sharing_class = FFT
alpha = 0.1
change_based_selection = True
accumulation = True
\ No newline at end of file
[DATASET]
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
; python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
rounds = 10
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.Sharing
sharing_class = Sharing
[DATASET]
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
; python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
# There are 734463 femnist samples
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
rounds = 47
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.SubSampling
sharing_class = SubSampling
alpha = 0.1
[DATASET]
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
; python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
# There are 734463 femnist samples
[TRAIN_PARAMS]
training_package = decentralizepy.training.ModelChangeAccumulator
training_class = ModelChangeAccumulator
rounds = 47
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
accumulation = True
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.TopK
sharing_class = TopK
alpha = 0.1
accumulation = True
\ No newline at end of file
[DATASET]
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
; python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
# There are 734463 femnist samples
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
rounds = 47
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.TopKParams
sharing_class = TopKParams
alpha = 0.1
[DATASET]
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
train_dir = /mnt/nfs/shared/leaf/data/femnist/per_user_data/train
test_dir = /mnt/nfs/shared/leaf/data/femnist/data/test
; python list of fractions below
sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.001
# There are 734463 femnist samples
[TRAIN_PARAMS]
training_package = decentralizepy.training.FrequencyWaveletAccumulator
training_class = FrequencyWaveletAccumulator
rounds = 47
full_epochs = False
batch_size = 16
shuffle = True
loss_package = torch.nn
loss_class = CrossEntropyLoss
wavelet = sym2
level = None
accumulation = True
[COMMUNICATION]
comm_package = decentralizepy.communication.TCP
comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.Wavelet
sharing_class = Wavelet
change_based_selection = True
alpha = 0.1
wavelet = sym2
level = None
accumulation = True
......@@ -24,7 +24,8 @@ def read_ini(file_path):
if __name__ == "__main__":
args = utils.get_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
# prevents accidental log overwrites
Path(args.log_dir).mkdir(parents=True, exist_ok=False)
log_level = {
"INFO": logging.INFO,
......
......@@ -42,6 +42,7 @@ install_requires =
pillow
smallworld
localconfig
PyWavelets
include_package_data = True
python_requires = >=3.6
[options.packages.find]
......
......@@ -17,6 +17,9 @@ class Model(nn.Module):
self.accumulated_gradients = []
self._param_count_ot = None
self._param_count_total = None
self.accumulated_frequency = None
self.prev_model_params = None
self.prev = None
def count_params(self, only_trainable=False):
"""
......
......@@ -92,6 +92,8 @@ class Node:
The object containing the mapping rank <--> uid
graph : decentralizepy.graphs
The object containing the global graph
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
reset_optimizer : int
......@@ -278,6 +280,8 @@ class Node:
The object containing the global graph
config : dict
A dictionary of configurations.
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
log_level : logging.Level
......@@ -443,6 +447,8 @@ class Node:
training_class = Training
epochs_per_round = 25
batch_size = 64
iterations : int
Number of iterations (communication steps) for which the model should be trained
log_dir : str
Logging directory
log_level : logging.Level
......
import base64
import json
import logging
import os
import pickle
from pathlib import Path
from time import time
import torch
import torch.fft as fft
from decentralizepy.sharing.Sharing import Sharing
class FFT(Sharing):
"""
This class implements the fft version of model sharing
It is based on PartialModel.py
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
pickle=True,
change_based_selection=True,
accumulation=True,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
pickle : bool
use pickle to serialize the model parameters
change_based_selection : bool
use frequency change to select topk frequencies
accumulation : bool
True if the indices to share should be selected based on accumulated frequency change
"""
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir
)
self.alpha = alpha
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
self.pickle = pickle
logging.info("subsampling pickling=" + str(pickle))
if self.save_shared:
# Only save for 2 procs: Save space
if rank != 0 and rank != 1:  # 'or' here would always be true, disabling saving everywhere
self.save_shared = False
if self.save_shared:
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
self.change_based_selection = change_based_selection
self.accumulation = accumulation
def apply_fft(self):
"""
Does fft transformation of the model parameters and selects topK (alpha) of them in the frequency domain
based on the change they underwent during the current training step
Returns
-------
tuple
(a,b). a: selected fft frequencies (complex numbers), b: Their indices.
"""
logging.info("Returning fft compressed model weights")
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
concated = torch.cat(tensors_to_cat, dim=0)
if self.change_based_selection:
flat_fft = fft.rfft(concated)
if self.accumulation:
logging.info(
"fft topk extract frequencies based on accumulated model frequency change"
)
diff = self.model.accumulated_frequency + (flat_fft - self.model.prev)
else:
diff = flat_fft - self.model.accumulated_frequency
_, index = torch.topk(
diff.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
)
else:
flat_fft = fft.rfft(concated)
_, index = torch.topk(
flat_fft.abs(), round(self.alpha * len(flat_fft)), dim=0, sorted=False
)
if self.accumulation:
self.model.accumulated_frequency[index] = 0.0
return flat_fft[index], index
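# Worked example (illustrative numbers, not from the original source): for a
# flattened model of 1000 real parameters, fft.rfft yields 501 complex
# coefficients, so alpha = 0.1 keeps round(0.1 * 501) = 50 of them, chosen by
# the magnitude of their (accumulated) change.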
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
"""
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
topk, indices = self.apply_fft()
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = indices.tolist() # is slow
shared_params["alpha"] = self.alpha
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
m = dict()
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["params"] = topk.numpy()
m["indices"] = indices.numpy()
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["alpha"])
)
return m
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Parameters
----------
m : dict
json dict received
Returns
-------
state_dict
state_dict of received
"""
if self.alpha > self.metadata_cap: # Share fully
return super().deserialized_model(m)
with torch.no_grad():
state_dict = self.model.state_dict()
if not self.dict_ordered:
raise NotImplementedError
shapes = []
lens = []
tensors_to_cat = []
for _, v in state_dict.items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
T = torch.cat(tensors_to_cat, dim=0)
indices = m["indices"]
alpha = m["alpha"]
params = m["params"]
params_tensor = torch.tensor(params)
indices_tensor = torch.tensor(indices)
ret = dict()
ret["indices"] = indices_tensor
ret["params"] = params_tensor
return ret
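# Note: unlike Sharing.deserialized_model, the dict returned above holds only
# the shared spectrum entries ("indices", "params"); step() below completes
# them with the local model's own FFT before averaging.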
def step(self):
"""
Perform a sharing step. Implements D-PSGD.
"""
t_start = time()
data = self.serialized_model()
t_post_serialize = time()
my_uid = self.mapping.get_uid(self.rank, self.machine_id)
all_neighbors = self.graph.neighbors(my_uid)
iter_neighbors = self.get_neighbors(all_neighbors)
data["degree"] = len(all_neighbors)
data["iteration"] = self.communication_round
for neighbor in iter_neighbors:
self.communication.send(neighbor, data)
t_post_send = time()
logging.info("Waiting for messages from neighbors")
while not self.received_from_all():
sender, data = self.communication.receive()
logging.debug("Received model from {}".format(sender))
degree = data["degree"]
iteration = data["iteration"]
del data["degree"]
del data["iteration"]
self.peer_deques[sender].append((degree, iteration, data))
logging.info(
"Deserialized received model from {} of iteration {}".format(
sender, iteration
)
)
t_post_recv = time()
logging.info("Starting model averaging after receiving from all neighbors")
total = None
weight_total = 0
# FFT of this model
shapes = []
lens = []
tensors_to_cat = []
for _, v in self.model.state_dict().items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
concated = torch.cat(tensors_to_cat, dim=0)
flat_fft = fft.rfft(concated)
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(n, iteration)
)
data = self.deserialized_model(data)
params = data["params"]
indices = data["indices"]
# use local data to complement
topkf = flat_fft.clone().detach()
topkf[indices] = params
weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metropolis-Hastings
weight_total += weight
if total is None:
total = weight * topkf
else:
total += weight * topkf
# Metropolis-Hastings
total += (1 - weight_total) * flat_fft
reverse_total = fft.irfft(total)
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(shapes[i])
start_index = end_index
self.model.load_state_dict(std_dict)
logging.info("Model averaging complete")
self.communication_round += 1
t_end = time()
logging.info(
"Sharing::step | Serialize: %f; Send: %f; Recv: %f; Averaging: %f; Total: %f",
t_post_serialize - t_start,
t_post_send - t_post_serialize,
t_post_recv - t_post_send,
t_end - t_post_recv,
t_end - t_start,
)
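# Arithmetic sketch of the Metropolis-Hastings weighting in step() (assumed
# values): with 3 neighbors, each of degree 3, every received spectrum gets
# weight 1 / (max(3, 3) + 1) = 0.25, weight_total reaches 0.75, and the local
# spectrum keeps the remaining 1 - 0.75 = 0.25, so the weights sum to 1
# before the inverse FFT.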
import base64
import json
import logging
import os
import pickle
from pathlib import Path
import torch
from decentralizepy.sharing.Sharing import Sharing
class SubSampling(Sharing):
"""
This class implements the subsampling version of model sharing
It is based on PartialModel.py
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
pickle=True,
layerwise=False,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
pickle : bool
use pickle to serialize the model parameters
"""
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir
)
self.alpha = alpha
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
# self.random_seed_generator = torch.Generator()
# # Will use the random device if supported by CPU, else uses the system time
# # In the latter case we could get duplicate seeds on some of the machines
# self.random_seed_generator.seed()
self.random_generator = torch.Generator()
# Will use the random device if supported by CPU, else uses the system time
# In the latter case we could get duplicate seeds on some of the machines
self.random_generator.seed()
self.seed = self.random_generator.initial_seed()
self.pickle = pickle
self.layerwise = layerwise
logging.info("subsampling pickling=" + str(pickle))
if self.save_shared:
# Only save for 2 procs: Save space
if rank != 0 and rank != 1:  # 'or' here would always be true, disabling saving everywhere
self.save_shared = False
if self.save_shared:
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
def apply_subsampling(self):
"""
Creates a random binary mask that is used to subsample the parameters that will be shared
Returns
-------
tuple
(a,b,c). a: the selected parameters as a flat vector, b: the random seed used to create the binary mask,
c: the alpha
"""
logging.info("Returning subsampling gradients")
if not self.layerwise:
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
concated = torch.cat(tensors_to_cat, dim=0)
curr_seed = self.seed + self.communication_round # is increased in step
self.random_generator.manual_seed(curr_seed)
# logging.debug("Subsampling seed for uid = " + str(self.uid) + " is: " + str(curr_seed))
# Or we could use torch.bernoulli
binary_mask = (
torch.rand(
size=(concated.size(dim=0),), generator=self.random_generator
)
<= self.alpha
)
subsample = concated[binary_mask]
# logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
return (subsample, curr_seed, self.alpha)
else:
values_list = []
offsets = [0]
off = 0
curr_seed = self.seed + self.communication_round # is increased in step
self.random_generator.manual_seed(curr_seed)
for _, v in self.model.state_dict().items():
flat = v.flatten()
binary_mask = (
torch.rand(
size=(flat.size(dim=0),), generator=self.random_generator
)
<= self.alpha
)
selected = flat[binary_mask]
values_list.append(selected)
off += selected.size(dim=0)
offsets.append(off)
subsample = torch.cat(values_list, dim=0)
return (subsample, curr_seed, self.alpha)
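# Protocol note: only the selected values, the seed, and alpha are
# transmitted; deserialized_model below regenerates the identical binary mask
# from the seed with a fresh torch.Generator, so the indices themselves never
# have to be sent.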
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
"""
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
subsample, seed, alpha = self.apply_subsampling()
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
# TODO: should store the shared indices and not the value
# shared_params[self.communication_round] = subsample.tolist() # is slow
shared_params["seed"] = seed
shared_params["alpha"] = alpha
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
m = dict()
if not self.dict_ordered:
raise NotImplementedError
m["seed"] = seed
m["alpha"] = alpha
m["params"] = subsample.numpy()
# logging.info("Converted dictionary to json")
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["seed"])) + len(
self.communication.encrypt(m["alpha"])
)
return m
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Parameters
----------
m : dict
json dict received
Returns
-------
state_dict
state_dict of received
"""
if self.alpha > self.metadata_cap: # Share fully
return super().deserialized_model(m)
with torch.no_grad():
state_dict = self.model.state_dict()
if not self.dict_ordered:
raise NotImplementedError
seed = m["seed"]
alpha = m["alpha"]
params = m["params"]
random_generator = (
torch.Generator()
) # new generator, such that we do not overwrite the other one
random_generator.manual_seed(seed)
shapes = []
lens = []
tensors_to_cat = []
binary_submasks = []
for _, v in state_dict.items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
if self.layerwise:
binary_mask = (
torch.rand(size=(t.size(dim=0),), generator=random_generator)
<= alpha
)
binary_submasks.append(binary_mask)
T = torch.cat(tensors_to_cat, dim=0)
params_tensor = torch.from_numpy(params)
if not self.layerwise:
binary_mask = (
torch.rand(size=(T.size(dim=0),), generator=random_generator)
<= alpha
)
else:
binary_mask = torch.cat(binary_submasks, dim=0)
logging.debug("Original tensor: {}".format(T[binary_mask]))
T[binary_mask] = params_tensor
logging.debug("Final tensor: {}".format(T[binary_mask]))
start_index = 0
for i, key in enumerate(state_dict):
end_index = start_index + lens[i]
state_dict[key] = T[start_index:end_index].reshape(shapes[i])
start_index = end_index
return state_dict
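# Back-of-envelope size estimate (an assumption, not a measurement): each
# coordinate is kept independently with probability alpha, so with alpha = 0.1
# roughly 10% of the flattened model is sent per round, the exact count
# varying from round to round because the mask is random.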
import json
import logging
import os
from pathlib import Path
import torch
from decentralizepy.sharing.Sharing import Sharing
class TopK(Sharing):
"""
This class implements topK selection of model parameters based on the model change since the beginning of the
communication step (use together with ModelChangeAccumulator)
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
accumulation=False,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir
)
self.alpha = alpha
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
self.accumulation = accumulation
if self.save_shared:
# Only save for 2 procs: Save space
if rank != 0 and rank != 1:  # 'or' here would always be true, disabling saving everywhere
self.save_shared = False
if self.save_shared:
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
def extract_top_gradients(self):
"""
Extract the indices and values of the topK gradients.
The gradients must have been accumulated.
Returns
-------
tuple
(a,b). a: The magnitudes of the topK gradients, b: Their indices.
"""
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
concated = torch.cat(tensors_to_cat, dim=0)
if self.accumulation:
logging.info(
"TopK extract gradients based on accumulated model parameter change"
)
diff = self.model.prev_model_params + (concated - self.model.prev)
else:
diff = concated - self.model.prev_model_params
G_topk = torch.abs(diff)
std, mean = torch.std_mean(G_topk, unbiased=False)
self.std = std.item()
self.mean = mean.item()
value, ind = torch.topk(
G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=False
)
# only needed when ModelChangeAccumulator.accumulation = True
# does not cause problems otherwise
if self.accumulation:
self.model.prev_model_params[ind] = 0.0 # torch.zeros((len(G_topk),))
return value, ind
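# Interpretation (not stated in the original source): with accumulation
# enabled, changes that were not selected keep building up in
# model.prev_model_params, and zeroing the shared indices above means each
# pending change is transmitted at most once; this resembles the residual
# (error-feedback) schemes used in gradient compression.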
def serialized_model(self):
"""
Convert model to a dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to a dict
"""
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
_, G_topk = self.extract_top_gradients()
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = G_topk.tolist()
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
logging.info("Extracting topk params")
tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
T = torch.cat(tensors_to_cat, dim=0)
T_topk = T[G_topk]
logging.info("Generating dictionary to send")
m = dict()
if not self.dict_ordered:
raise NotImplementedError
m["indices"] = G_topk.numpy()
m["params"] = T_topk.numpy()
assert len(m["indices"]) == len(m["params"])
logging.info("Elements sending: {}".format(len(m["indices"])))
logging.info("Generated dictionary to send")
logging.info("Converted dictionary to pickle")
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"]))
return m
def deserialized_model(self, m):
"""
Convert received dict to state_dict.
Parameters
----------
m : dict
dict received
Returns
-------
state_dict
state_dict of received
"""
if self.alpha > self.metadata_cap: # Share fully
return super().deserialized_model(m)
with torch.no_grad():
state_dict = self.model.state_dict()
if not self.dict_ordered:
raise NotImplementedError
shapes = []
lens = []
tensors_to_cat = []
for _, v in state_dict.items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
T = torch.cat(tensors_to_cat, dim=0)
index_tensor = torch.tensor(m["indices"])
logging.debug("Original tensor: {}".format(T[index_tensor]))
T[index_tensor] = torch.tensor(m["params"])
logging.debug("Final tensor: {}".format(T[index_tensor]))
start_index = 0
for i, key in enumerate(state_dict):
end_index = start_index + lens[i]
state_dict[key] = T[start_index:end_index].reshape(shapes[i])
start_index = end_index
return state_dict
import json
import logging
import os
from pathlib import Path
import torch
from decentralizepy.sharing.Sharing import Sharing
class TopKParams(Sharing):
"""
This class implements layerwise topK selection of model parameters by magnitude, a variant of partial model sharing.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
"""
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir
)
self.alpha = alpha
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
if self.save_shared:
# Only save for 2 procs: Save space
if rank != 0 and rank != 1:  # 'or' here would always be true, disabling saving everywhere
self.save_shared = False
if self.save_shared:
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
def extract_top_params(self):
"""
Extract the indices and values of the topK params, layerwise.
Returns
-------
tuple
(a,b,c). a: The topK params, b: Their indices, c: The offsets
"""
logging.info("Returning TopKParams gradients")
values_list = []
index_list = []
offsets = [0]
off = 0
for _, v in self.model.state_dict().items():
flat = v.flatten()
values, index = torch.topk(
flat.abs(), round(self.alpha * flat.size(dim=0)), dim=0, sorted=False
)
values_list.append(flat[index])
index_list.append(index)
off += values.size(dim=0)
offsets.append(off)
cat_values = torch.cat(values_list, dim=0)
cat_index = torch.cat(index_list, dim=0)
# logging.debug("Subsampling vector is of size: " + str(subsample.size(dim = 0)))
return (cat_values, cat_index, offsets)
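# Example of the offsets bookkeeping (assumed layer sizes): for two layers
# with 100 and 60 parameters and alpha = 0.1, topk keeps 10 and 6 values, so
# offsets = [0, 10, 16]; deserialized_model below uses consecutive offset
# pairs to slice each layer's values and indices back out.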
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
"""
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
values, index, offsets = self.extract_top_params()
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = index.tolist()
# TODO: store offsets
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
logging.info("Extracting topk params")
logging.info("Generating dictionary to send")
m = dict()
if not self.dict_ordered:
raise NotImplementedError
m["indices"] = index.numpy()
m["params"] = values.numpy()
m["offsets"] = offsets
assert len(m["indices"]) == len(m["params"])
logging.info("Elements sending: {}".format(len(m["indices"])))
logging.info("Generated dictionary to send")
# for key in m:
# m[key] = json.dumps(m[key])
logging.info("Converted dictionary to json")
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["offsets"])
)
return m
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Parameters
----------
m : dict
json dict received
Returns
-------
state_dict
state_dict of received
"""
if self.alpha > self.metadata_cap: # Share fully
return super().deserialized_model(m)
with torch.no_grad():
state_dict = self.model.state_dict()
if not self.dict_ordered:
raise NotImplementedError
shapes = []
lens = []
tensors_to_cat = []
offsets = m["offsets"]
params = torch.tensor(m["params"])
indices = torch.tensor(m["indices"])
for i, (_, v) in enumerate(state_dict.items()):
shapes.append(v.shape)
t = v.flatten().clone().detach() # it is not always copied
lens.append(t.shape[0])
index = indices[offsets[i] : offsets[i + 1]]
t[index] = params[offsets[i] : offsets[i + 1]]
tensors_to_cat.append(t)
start_index = 0
for i, key in enumerate(state_dict):
end_index = start_index + lens[i]
state_dict[key] = tensors_to_cat[i].reshape(shapes[i])
start_index = end_index
return state_dict
import base64
import json
import logging
import os
import pickle
from pathlib import Path
from time import time
import pywt
import torch
from decentralizepy.sharing.Sharing import Sharing
class Wavelet(Sharing):
"""
This class implements the wavelet version of model sharing
It is based on PartialModel.py
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
pickle=True,
wavelet="haar",
level=4,
change_based_selection=True,
accumulation=False,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
pickle : bool
use pickle to serialize the model parameters
wavelet: str
name of the wavelet to be used in gradient compression
level: int
decomposition level of the wavelet transform (None selects the maximum level)
change_based_selection : bool
use frequency change to select topk frequencies
accumulation : bool
True if the indices to share should be selected based on accumulated frequency change
"""
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir
)
self.alpha = alpha
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
self.pickle = pickle
self.wavelet = wavelet
self.level = level
self.accumulation = accumulation
logging.info("subsampling pickling=" + str(pickle))
if self.save_shared:
# Only save for 2 procs: Save space
if rank != 0 and rank != 1:  # 'or' here would always be true, disabling saving everywhere
self.save_shared = False
if self.save_shared:
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
self.change_based_selection = change_based_selection
def apply_wavelet(self):
"""
Does wavelet transformation of the model parameters and selects topK (alpha) of them in the frequency domain
based on the change they underwent during the current training step
Returns
-------
tuple
(a,b). a: selected wavelet coefficients, b: Their indices.
"""
logging.info("Returning dwt compressed model weights")
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
concated = torch.cat(tensors_to_cat, dim=0)
if self.change_based_selection:
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(
coeff
) # coeff_slices will be reproduced on the receiver
data = data.ravel()
if self.accumulation:
logging.info(
"wavelet topk extract frequencies based on accumulated model frequency change"
)
diff = self.model.accumulated_frequency + (data - self.model.prev)
else:
diff = data - self.model.accumulated_frequency
_, index = torch.topk(
torch.from_numpy(diff).abs(),
round(self.alpha * len(data)),
dim=0,
sorted=False,
)
else:
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(
coeff
) # coeff_slices will be reproduced on the receiver
data = data.ravel()
_, index = torch.topk(
torch.from_numpy(data).abs(),
round(self.alpha * len(data)),
dim=0,
sorted=False,
)
if self.accumulation:
self.model.accumulated_frequency[index] = 0.0
return torch.from_numpy(data[index]), index
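# Round-trip sketch: pywt.wavedec returns a list of coefficient arrays,
# pywt.coeffs_to_array packs them into a single ndarray plus the slices
# needed to undo the packing, and step() below rebuilds the coefficient list
# with pywt.array_to_coeffs before inverting with pywt.waverec.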
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
"""
if self.alpha > self.metadata_cap: # Share fully
return super().serialized_model()
with torch.no_grad():
topk, indices = self.apply_wavelet()
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = indices.tolist() # is slow
shared_params["alpha"] = self.alpha
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
m = dict()
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["params"] = topk.numpy()
m["indices"] = indices.numpy()
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["alpha"])
)
return m
def deserialized_model(self, m):
"""
Convert received json dict to state_dict.
Parameters
----------
m : dict
json dict received
Returns
-------
state_dict
state_dict of received
"""
if self.alpha > self.metadata_cap: # Share fully
return super().deserialized_model(m)
with torch.no_grad():
state_dict = self.model.state_dict()
if not self.dict_ordered:
raise NotImplementedError
shapes = []
lens = []
tensors_to_cat = []
for _, v in state_dict.items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
T = torch.cat(tensors_to_cat, dim=0)
indices = m["indices"]
alpha = m["alpha"]
params = m["params"]
params_tensor = torch.tensor(params)
indices_tensor = torch.tensor(indices)
ret = dict()
ret["indices"] = indices_tensor
ret["params"] = params_tensor
return ret
def step(self):
"""
Perform a sharing step. Implements D-PSGD.
"""
t_start = time()
data = self.serialized_model()
t_post_serialize = time()
my_uid = self.mapping.get_uid(self.rank, self.machine_id)
all_neighbors = self.graph.neighbors(my_uid)
iter_neighbors = self.get_neighbors(all_neighbors)
data["degree"] = len(all_neighbors)
data["iteration"] = self.communication_round
for neighbor in iter_neighbors:
self.communication.send(neighbor, data)
t_post_send = time()
logging.info("Waiting for messages from neighbors")
while not self.received_from_all():
sender, data = self.communication.receive()
logging.debug("Received model from {}".format(sender))
degree = data["degree"]
iteration = data["iteration"]
del data["degree"]
del data["iteration"]
self.peer_deques[sender].append((degree, iteration, data))
logging.info(
"Deserialized received model from {} of iteration {}".format(
sender, iteration
)
)
t_post_recv = time()
logging.info("Starting model averaging after receiving from all neighbors")
total = None
weight_total = 0
# FFT of this model
shapes = []
lens = []
tensors_to_cat = []
# TODO: should we detach
for _, v in self.model.state_dict().items():
shapes.append(v.shape)
t = v.flatten()
lens.append(t.shape[0])
tensors_to_cat.append(t)
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
wt_params, coeff_slices = pywt.coeffs_to_array(
coeff
) # coeff_slices will be reproduced on the receiver
shape = wt_params.shape
wt_params = wt_params.ravel()
for i, n in enumerate(self.peer_deques):
degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(n, iteration)
)
data = self.deserialized_model(data)
params = data["params"]
indices = data["indices"]
# use local data to complement
topkwf = wt_params.copy() # .clone().detach()
topkwf[indices] = params
topkwf = torch.from_numpy(topkwf.reshape(shape))
weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metropolis-Hastings
weight_total += weight
if total is None:
total = weight * topkwf
else:
total += weight * topkwf
# Metropolis-Hastings
total += (1 - weight_total) * wt_params
avg_wf_params = pywt.array_to_coeffs(
total, coeff_slices, output_format="wavedec"
)
reverse_total = torch.from_numpy(
pywt.waverec(avg_wf_params, wavelet=self.wavelet)
)
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(shapes[i])
start_index = end_index
self.model.load_state_dict(std_dict)
logging.info("Model averaging complete")
self.communication_round += 1
t_end = time()
logging.info(
"Sharing::step | Serialize: %f; Send: %f; Recv: %f; Averaging: %f; Total: %f",
t_post_serialize - t_start,
t_post_send - t_post_serialize,
t_post_recv - t_post_send,
t_end - t_post_recv,
t_end - t_start,
)
import logging
import torch
from torch import fft
from decentralizepy.training.Training import Training
class FrequencyAccumulator(Training):
"""
This class implements the training module, which also accumulates the fft frequency representation of the model at the beginning of each communication round.
"""
def __init__(
self,
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds="",
full_epochs="",
batch_size="",
shuffle="",
accumulation=True,
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
model : torch.nn.Module
Neural Network for training
optimizer : torch.optim
Optimizer to learn parameters
loss : function
Loss function
log_dir : str
Directory to log the model change.
rounds : int, optional
Number of steps/epochs per training call
full_epochs: bool, optional
True if 1 round = 1 epoch. False if 1 round = 1 minibatch
batch_size : int, optional
Number of items to learn over, in one batch
shuffle : bool
True if the dataset should be shuffled before training.
accumulation : bool
True if the model change should be accumulated across communication steps
"""
super().__init__(
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds,
full_epochs,
batch_size,
shuffle,
)
self.accumulation = accumulation
def train(self, dataset):
"""
Does one training iteration.
If self.accumulation is True then it accumulates model fft frequency changes in model.accumulated_frequency.
Otherwise it stores the current fft frequency representation of the model in model.accumulated_frequency.
Parameters
----------
dataset : decentralizepy.datasets.Dataset
The training dataset. Should implement get_trainset(batch_size, shuffle)
"""
# this looks at the change from the last round averaging of the frequencies
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
concated = torch.cat(tensors_to_cat, dim=0)
flat_fft = fft.rfft(concated)
if self.accumulation:
if self.model.accumulated_frequency is None:
logging.info("Initialize fft frequency accumulation")
self.model.accumulated_frequency = torch.zeros_like(flat_fft)
self.model.prev = flat_fft
else:
logging.info("fft frequency accumulation step")
self.model.accumulated_frequency += flat_fft - self.model.prev
self.model.prev = flat_fft
else:
logging.info("fft frequency accumulation reset")
self.model.accumulated_frequency = flat_fft
super().train(dataset)
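# Configuration note (taken from the config files in this commit): this
# trainer maintains model.accumulated_frequency and model.prev, which the FFT
# sharing class reads, so the two are paired via
# training_class = FrequencyAccumulator and sharing_class = FFT, with the
# same accumulation flag in both sections.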
import logging
import numpy as np
import pywt
import torch
from decentralizepy.training.Training import Training
class FrequencyWaveletAccumulator(Training):
"""
This class implements the training module, which also accumulates the wavelet frequency representation of the model at the beginning of each communication round.
"""
def __init__(
self,
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds="",
full_epochs="",
batch_size="",
shuffle="",
wavelet="haar",
level=4,
accumulation=True,
):
"""
Constructor
Parameters
----------
rank : int
Rank of process local to the machine
machine_id : int
Machine ID on which the process is running
mapping : decentralizepy.mappings
The object containing the mapping rank <--> uid
model : torch.nn.Module
Neural Network for training
optimizer : torch.optim
Optimizer to learn parameters
loss : function
Loss function
log_dir : str
Directory to log the model change.
rounds : int, optional
Number of steps/epochs per training call
full_epochs: bool, optional
True if 1 round = 1 epoch. False if 1 round = 1 minibatch
batch_size : int, optional
Number of items to learn over, in one batch
shuffle : bool
True if the dataset should be shuffled before training.
accumulation : bool
True if the model change should be accumulated across communication steps
"""
super().__init__(
rank,
machine_id,
mapping,
model,
optimizer,
loss,
log_dir,
rounds,
full_epochs,
batch_size,
shuffle,
)
self.wavelet = wavelet
self.level = level
self.accumulation = accumulation
def train(self, dataset):
"""
Does one training iteration.
If self.accumulation is True then it accumulates model wavelet frequency changes in model.accumulated_frequency.
Otherwise it stores the current wavelet frequency representation of the model in model.accumulated_frequency.
Parameters
----------
dataset : decentralizepy.datasets.Dataset
The training dataset. Should implement get_trainset(batch_size, shuffle)
"""
# this looks at the change from the last round averaging of the frequencies
tensors_to_cat = [v.data.flatten() for _, v in self.model.state_dict().items()]
concated = torch.cat(tensors_to_cat, dim=0)
coeff = pywt.wavedec(concated.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
data = data.ravel()
if self.accumulation:
if self.model.accumulated_frequency is None:
logging.info("Initialize wavelet frequency accumulation")
self.model.accumulated_frequency = np.zeros_like(
data
) # torch.zeros_like(data)
self.model.prev = data
else:
logging.info("wavelet frequency accumulation step")
self.model.accumulated_frequency += data - self.model.prev
self.model.prev = data
else:
logging.info("wavelet frequency accumulation reset")
self.model.accumulated_frequency = data
super().train(dataset)
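# Configuration note (from the wavelet config in this commit): this trainer is
# paired with sharing_class = Wavelet, and the wavelet and level values must
# match across the [TRAIN_PARAMS] and [SHARING] sections so that sender and
# receiver compute the same coefficient layout.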