import json
import logging
import os
from pathlib import Path
import numpy as np
import pywt
import torch
from decentralizepy.sharing.DPSGDRWAsync import DPSGDRWAsync
from decentralizepy.utils import conditional_value, identity
def change_transformer_wavelet(x, wavelet, level):
"""
Transforms the model changes into wavelet frequency domain
Parameters
----------
x : torch.Tensor
Model change in the space domain
wavelet : str
name of the wavelet to be used in gradient compression
level: int
name of the wavelet to be used in gradient compression
Returns
-------
x : torch.Tensor
Representation of the change int the wavelet domain
"""
coeff = pywt.wavedec(x, wavelet, level=level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
return torch.from_numpy(data.ravel())
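

# Illustrative sketch only (not used by the class): the inverse of change_transformer_wavelet,
# mirroring the reconstruction performed later in JwinsDPSGDAsync._averaging. The helper name
# and defaults are assumptions for illustration, not part of decentralizepy.
def _example_wavelet_roundtrip(x, wavelet="haar", level=4):
    """
    Transform a flat tensor into the wavelet domain and reconstruct it again.
    """
    coeffs = pywt.wavedec(x.numpy(), wavelet, level=level)
    data, coeff_slices = pywt.coeffs_to_array(coeffs)
    # Inverse: rebuild the coefficient list from the flat array and apply waverec.
    coeffs_back = pywt.array_to_coeffs(data, coeff_slices, output_format="wavedec")
    reconstructed = torch.from_numpy(pywt.waverec(coeffs_back, wavelet=wavelet))
    # waverec may pad to an even length, so trim back to the original size.
    return reconstructed[: x.shape[0]]

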
class JwinsDPSGDAsync(DPSGDRWAsync):
"""
This class implements the vanilla version of partial model sharing.
"""
def __init__(
self,
rank,
machine_id,
communication,
mapping,
graph,
model,
dataset,
log_dir,
alpha=1.0,
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
wavelet="haar",
level=4,
change_based_selection=True,
save_accumulated="",
accumulation=False,
accumulate_averaging_changes=False,
rw_chance=1,
rw_length=4,
comm_interval=0.5,
min_interval=0.001,
max_lag=2,
        lower_bound=0.1,
metro_hastings=True,
):
"""
Constructor
Parameters
----------
rank : int
Local rank
machine_id : int
Global machine id
communication : decentralizepy.communication.Communication
Communication module used to send and receive messages
mapping : decentralizepy.mappings.Mapping
Mapping (rank, machine_id) -> uid
graph : decentralizepy.graphs.Graph
            Graph representing neighbors
model : decentralizepy.models.Model
Model to train
dataset : decentralizepy.datasets.Dataset
Dataset for sharing data. Not implemented yet! TODO
log_dir : str
Location to write shared_params (only writing for 2 procs per machine)
alpha : float
Percentage of model to share
dict_ordered : bool
Specifies if the python dict maintains the order of insertion
save_shared : bool
Specifies if the indices of shared parameters should be logged
metadata_cap : float
Share full model when self.alpha > metadata_cap
wavelet: str
name of the wavelet to be used in gradient compression
        level : int
            Decomposition level of the wavelet transform
change_based_selection : bool
use frequency change to select topk frequencies
        save_accumulated : bool
            True if the accumulated weight change in the wavelet domain should be written
            to file. In case of accumulation, the accumulated change is stored.
        accumulation : bool
            True if the indices to share should be selected based on the accumulated
            frequency change
        accumulate_averaging_changes : bool
            True if the accumulation should also account for the model change due to averaging
        lower_bound : float
            If greater than 0, every wavelet coefficient is additionally shared often enough
            to be selected in at least this fraction of the communication rounds
        metro_hastings : bool
            If True, average with Metropolis-Hastings weights; otherwise use a per-coefficient
            weighting based on how many received models contributed to each coefficient
        """
self.wavelet = wavelet
self.level = level
super().__init__(
rank, machine_id, communication, mapping, graph, model, dataset, log_dir, rw_chance, rw_length, comm_interval, min_interval, max_lag
)
self.alpha = alpha
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.accumulation = accumulation
self.save_accumulated = conditional_value(save_accumulated, "", False)
self.change_transformer = lambda x: change_transformer_wavelet(x, wavelet, level)
self.accumulate_averaging_changes = accumulate_averaging_changes
# getting the initial model
self.shapes = []
self.lens = []
with torch.no_grad():
tensors_to_cat = []
for _, v in self.model.state_dict().items():
self.shapes.append(v.shape)
t = v.flatten()
self.lens.append(t.shape[0])
tensors_to_cat.append(t)
self.init_model = torch.cat(tensors_to_cat, dim=0)
if self.accumulation:
self.model.accumulated_changes = torch.zeros_like(
self.change_transformer(self.init_model)
)
self.prev = self.init_model
self.number_of_params = self.init_model.shape[0]
if self.save_accumulated:
self.model_change_path = os.path.join(
self.log_dir, "model_change/{}".format(self.rank)
)
Path(self.model_change_path).mkdir(parents=True, exist_ok=True)
self.model_val_path = os.path.join(
self.log_dir, "model_val/{}".format(self.rank)
)
Path(self.model_val_path).mkdir(parents=True, exist_ok=True)
# Only save for 2 procs: Save space
if self.save_shared and not (rank == 0 or rank == 1):
self.save_shared = False
if self.save_shared:
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
self.model.shared_parameters_counter = torch.zeros(
self.change_transformer(self.init_model).shape[0], dtype=torch.int32
)
self.change_based_selection = change_based_selection
        # Do a dummy transform to get the shape and coefficient slices
coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
data, coeff_slices = pywt.coeffs_to_array(coeff)
self.wt_shape = data.shape
self.coeff_slices = coeff_slices
self.lower_bound = lower_bound
self.metro_hastings = metro_hastings
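        # With lower_bound > 0, apply_wavelet() only starts topping up rarely shared
        # coefficients after 1 / lower_bound communication rounds; e.g. lower_bound = 0.1
        # starts enforcing the bound after round 10.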
if self.lower_bound > 0:
self.start_lower_bounding_at = 1 / self.lower_bound
def apply_wavelet(self):
"""
Does wavelet transformation of the model parameters and selects topK (alpha) of them in the frequency domain
based on the undergone change during the current training step
Returns
-------
tuple
(a,b). a: selected wavelet coefficients, b: Their indices.
"""
data = self.pre_share_model_transformed
if self.change_based_selection:
diff = self.model.model_change
_, index = torch.topk(
diff.abs(),
round(self.alpha * len(diff)),
dim=0,
sorted=False,
)
else:
_, index = torch.topk(
data.abs(),
round(self.alpha * len(data)),
dim=0,
sorted=False,
)
ind, _ = torch.sort(index)
if self.lower_bound == 0.0:
return data[ind].clone().detach(), ind
if self.communication_round > self.start_lower_bounding_at:
# because the superclass increases it where it is inconvenient for this subclass
currently_shared = self.model.shared_parameters_counter.clone().detach()
currently_shared[ind] += 1
ind_small = (
currently_shared < self.communication_round * self.lower_bound
).nonzero(as_tuple=True)[0]
ind_small_unique = np.setdiff1d(
ind_small.numpy(), ind.numpy(), assume_unique=True
)
take_max = round(self.lower_bound * self.alpha * data.shape[0])
logging.info(
"lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max
)
if take_max > ind_small_unique.shape[0]:
take_max = ind_small_unique.shape[0]
to_take = torch.rand(ind_small_unique.shape[0])
_, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False)
ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take]
logging.info("lower bounding: %i %i", len(ind), len(ind_bound))
            # val = torch.concat(val, G_topk[ind_bound])  # not really needed, as these are abs values and not further used
ind = torch.cat([ind, ind_bound])
index, _ = torch.sort(ind)
return data[index].clone().detach(), index
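    # Worked example with illustrative numbers: for alpha = 0.1, lower_bound = 0.1 and
    # 1000 wavelet coefficients, apply_wavelet() first selects the 100 coefficients with
    # the largest (accumulated) change, then adds up to round(0.1 * 0.1 * 1000) = 10
    # randomly chosen coefficients whose sharing counter lags behind 10% of the rounds.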
def serialized_model(self):
"""
Convert model to json dict. self.alpha specifies the fraction of model to send.
Returns
-------
dict
Model converted to json dict
"""
m = dict()
if self.alpha >= self.metadata_cap: # Share fully
data = self.pre_share_model_transformed
m["params"] = data.numpy()
if self.model.accumulated_changes is not None:
self.model.accumulated_changes = torch.zeros_like(
self.model.accumulated_changes
)
return m
with torch.no_grad():
topk, self.G_topk = self.apply_wavelet()
self.model.shared_parameters_counter[self.G_topk] += 1
self.model.rewind_accumulation(self.G_topk)
if self.save_shared:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = self.G_topk.tolist()
with open(
os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(shared_params, of)
if not self.dict_ordered:
raise NotImplementedError
m["alpha"] = self.alpha
m["indices"] = self.G_topk.numpy().astype(np.int32)
m["params"] = topk.numpy()
m["send_partial"] = True
assert len(m["indices"]) == len(m["params"])
logging.info("Elements sending: {}".format(len(m["indices"])))
logging.info("Generated dictionary to send")
logging.info("Converted dictionary to pickle")
return m
def deserialized_model(self, m):
"""
Convert received dict to state_dict.
Parameters
----------
m : dict
received dict
Returns
-------
state_dict
state_dict of received
"""
ret = dict()
if "send_partial" not in m:
params = m["params"]
params_tensor = torch.tensor(params)
ret["params"] = params_tensor
return ret
with torch.no_grad():
if not self.dict_ordered:
raise NotImplementedError
params_tensor = torch.tensor(m["params"])
indices_tensor = torch.tensor(m["indices"], dtype=torch.long)
ret = dict()
ret["indices"] = indices_tensor
ret["params"] = params_tensor
ret["send_partial"] = True
return ret
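    # Sketch of the exchanged dictionary (illustrative values): a partial message from
    # serialized_model() has the form
    #   {"alpha": 0.1, "indices": <int32 positions of selected wavelet coefficients>,
    #    "params": <float values of those coefficients>, "send_partial": True},
    # while a full share only carries "params"; deserialized_model() turns both back
    # into torch tensors.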
def _averaging(self):
"""
Averages the received model with the local model
"""
with torch.no_grad():
total = None
weight_total = 0
wt_params = self.pre_share_model_transformed
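            # Two averaging modes: with metro_hastings, received (partial) models are combined
            # with Metropolis-Hastings weights; without it, a per-coefficient weight vector
            # counts how many models (including the local one) contributed to each coefficient.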
if not self.metro_hastings:
weight_vector = torch.ones_like(wt_params)
datas = []
batch = self._preprocessing_received_models()
for n, vals in batch.items():
if len(vals) > 1:
data = None
degree = 0
# this should no longer happen, unless we get two rw from the same originator
logging.info("averaging double messages for %i", n)
for val in vals:
degree_sub, iteration, data_sub = val
if data is None:
data = data_sub
                            degree = degree_sub
else:
for key, weight_val in data_sub.items():
data[key] += weight_val
degree = max(degree, degree_sub)
for key, weight_val in data.items():
data[key] /= len(vals)
else:
degree, iteration, data = vals[0]
#degree, iteration, data = self.peer_deques[n].popleft()
logging.debug(
"Averaging model from neighbor {} of iteration {}".format(
n, iteration
)
)
#data = self.deserialized_model(data)
params = data["params"]
if "indices" in data:
indices = data["indices"]
if not self.metro_hastings:
weight_vector[indices] += 1
topkwf = torch.zeros_like(wt_params)
topkwf[indices] = params
topkwf = topkwf.reshape(self.wt_shape)
datas.append(topkwf)
else:
# use local data to complement
topkwf = wt_params.clone().detach()
topkwf[indices] = params
topkwf = topkwf.reshape(self.wt_shape)
else:
topkwf = params.reshape(self.wt_shape)
if not self.metro_hastings:
weight_vector += 1
datas.append(topkwf)
if self.metro_hastings:
weight = 1 / (max(len(self.peer_deques), degree) + 1) # Metro-Hastings
weight_total += weight
if total is None:
total = weight * topkwf
else:
total += weight * topkwf
if not self.metro_hastings:
weight_vector = 1.0 / weight_vector
# speed up by exploiting sparsity
total = wt_params * weight_vector
for d in datas:
total += d * weight_vector
else:
# Metro-Hastings
total += (1 - weight_total) * wt_params
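            # Map the averaged wavelet coefficients back to parameter space (inverse wavelet
            # transform) and load the result into the model.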
avg_wf_params = pywt.array_to_coeffs(
total.numpy(), self.coeff_slices, output_format="wavedec"
)
reverse_total = torch.from_numpy(
pywt.waverec(avg_wf_params, wavelet=self.wavelet)
)
start_index = 0
std_dict = {}
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
std_dict[key] = reverse_total[start_index:end_index].reshape(
self.shapes[i]
)
start_index = end_index
self.model.load_state_dict(std_dict)
to_cat = []
for _, v in self.model.state_dict().items():
vf = v.clone().detach().flatten()
to_cat.append(vf)
self.init_model = torch.cat(to_cat)
self._transformed = self.change_transformer(self.init_model)
def _pre_step(self):
"""
Called at the beginning of step.
"""
        logging.info("JwinsDPSGDAsync _pre_step")
with torch.no_grad():
tensors_to_cat = [
v.data.flatten() for _, v in self.model.state_dict().items()
]
self.pre_share_model = torch.cat(tensors_to_cat, dim=0)
# Would only need one of the transforms
self.pre_share_model_transformed = self.change_transformer(
self.pre_share_model
)
def _post_step(self):
"""
Called at the end of step.
"""
change = self.change_transformer(self.pre_share_model - self.init_model) # self.init_model is set in _averaging
if self.accumulation:
            # Need to accumulate here, as the accumulated change gets rewound during the sharing step
self.model.accumulated_changes += change
change = self.model.accumulated_changes.clone().detach()
# stores change of the model due to training, change due to averaging is not accounted
self.model.model_change = change
def _send_rw(self):
def send():
            # the data may be sent more than once; this keeps the code simple for now
if self.alpha >= self.metadata_cap:
rw_data = {
"params": self.init_model.numpy(),
"rw": True,
"degree": self.number_of_neighbors,
"iteration": self.communication_round,
"visited": [self.uid],
"fuel": self.rw_length - 1,
}
else:
rw_data = {
"params": self._transformed[self.G_topk].numpy(),
"indices": self.G_topk.numpy().astype(np.int32),
"rw": True,
"degree": self.number_of_neighbors,
"iteration": self.communication_round,
"visited": [self.uid],
"fuel": self.rw_length - 1,
"send_partial": True,
}
logging.info("new rw message")
self.communication.send(None, rw_data)
rw_chance = self.rw_chance
        self.serialized_model()  # dummy call to compute self.G_topk
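        # rw_chance acts as the expected number of random walks per round: the integer part
        # starts that many walks, and the fractional remainder starts one more walk with the
        # corresponding probability.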
while rw_chance >= 1.0:
# TODO: make sure they are not sent to the same neighbour
send()
rw_chance -= 1
rw_now = torch.rand(size=(1,), generator=self.random_generator).item()
if rw_now < rw_chance:
send()
def save_vector(self, v, s):
"""
Saves the given vector to the file.
Parameters
----------
        v : torch.Tensor
The torch tensor to write to file
s : str
Path to folder to write to
"""
output_dict = dict()
output_dict["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v1 in self.model.state_dict().items():
shapes[k] = list(v1.shape)
output_dict["shapes"] = shapes
output_dict["tensor"] = v.tolist()
with open(
os.path.join(
s,
"{}.json".format(self.communication_round + 1),
),
"w",
) as of:
json.dump(output_dict, of)
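    # Sketch (illustration only) of reading such a file back; the file name is a
    # hypothetical example:
    #
    #   with open(os.path.join(self.model_change_path, "1.json")) as f:
    #       saved = json.load(f)
    #   change = torch.tensor(saved["tensor"])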
def save_change(self):
"""
        Saves the model change in the wavelet domain for every iteration
"""
self.save_vector(self.model.model_change, self.model_change_path)