Project: SaCS / Semester-Projects / spring22 / Jeffrey_Wigger_Master_Project

Commit f005b1e4, authored 2 years ago by Jeffrey Wigger
Commit message: WaveletBound.py
Parent: d20cf02d

1 changed file, 432 additions, 0 deletions:
src/decentralizepy/sharing/WaveletBound.py (new file, mode 100644): +432 −0
import json
import logging
import os
from pathlib import Path
from time import time

import numpy as np
import pywt
import torch

from decentralizepy.sharing.LowerBoundTopK import LowerBoundTopK
def change_transformer_wavelet(x, wavelet, level):
    """
    Transforms the model changes into the wavelet frequency domain

    Parameters
    ----------
    x : torch.Tensor
        Model change in the space domain
    wavelet : str
        Name of the wavelet to be used in gradient compression
    level : int
        Decomposition level of the wavelet transform used in gradient compression

    Returns
    -------
    x : torch.Tensor
        Representation of the change in the wavelet domain

    """
    coeff = pywt.wavedec(x, wavelet, level=level)
    data, coeff_slices = pywt.coeffs_to_array(coeff)
    return torch.from_numpy(data.ravel())
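As a sanity check on what this transform produces, here is a minimal standalone sketch of the forward and inverse path used in this file; `flat_change` is a hypothetical stand-in for the flattened model change, and the wavelet/level values just mirror the class defaults.

import numpy as np
import pywt
import torch

# Hypothetical flattened model-change vector; any 1-D float tensor works.
flat_change = torch.randn(1024)

# Forward transform, mirroring change_transformer_wavelet with haar / level 4.
coeff = pywt.wavedec(flat_change.numpy(), "haar", level=4)
data, coeff_slices = pywt.coeffs_to_array(coeff)
wt = torch.from_numpy(data.ravel())  # the vector that is sparsified and shared

# Inverse transform, mirroring the reconstruction at the end of _averaging.
coeff_back = pywt.array_to_coeffs(wt.numpy(), coeff_slices, output_format="wavedec")
recovered = torch.from_numpy(pywt.waverec(coeff_back, wavelet="haar"))

# waverec may pad the output, so compare only on the original length.
assert torch.allclose(flat_change, recovered[: flat_change.numel()], atol=1e-5)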
class WaveletBound(LowerBoundTopK):
    """
    This class implements the wavelet version of model sharing.
    It is based on PartialModel.py.

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
        save_shared=False,
        metadata_cap=1.0,
        wavelet="haar",
        level=4,
        change_based_selection=True,
        save_accumulated="",
        accumulation=False,
        accumulate_averaging_changes=False,
        lower_bound=0.1,
        metro_hastings=True,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph representing the neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine)
        alpha : float
            Percentage of the model to share
        dict_ordered : bool
            Specifies if the python dict maintains the order of insertion
        save_shared : bool
            Specifies if the indices of shared parameters should be logged
        metadata_cap : float
            Share the full model when self.alpha > metadata_cap
        wavelet : str
            Name of the wavelet to be used in gradient compression
        level : int
            Decomposition level of the wavelet transform used in gradient compression
        change_based_selection : bool
            Use the frequency change to select the topK frequencies
        save_accumulated : bool
            True if the accumulated weight change in the wavelet domain should be written to file. In case of
            accumulation, the accumulated change is stored.
        accumulation : bool
            True if the indices to share should be selected based on the accumulated frequency change
        accumulate_averaging_changes : bool
            True if the accumulation should account for the model change due to averaging

        """
        self.wavelet = wavelet
        self.level = level

        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            lower_bound,
            metro_hastings,
            alpha=alpha,
            dict_ordered=dict_ordered,
            save_shared=save_shared,
            metadata_cap=metadata_cap,
            accumulation=accumulation,
            save_accumulated=save_accumulated,
            change_transformer=lambda x: change_transformer_wavelet(x, wavelet, level),
            accumulate_averaging_changes=accumulate_averaging_changes,
        )
        self.change_based_selection = change_based_selection

        # Do a dummy transform to get the shape and the coefficient slices
        coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
        data, coeff_slices = pywt.coeffs_to_array(coeff)
        self.wt_shape = data.shape
        self.coeff_slices = coeff_slices
    def apply_wavelet(self):
        """
        Applies the wavelet transform to the model parameters and selects the topK (alpha fraction) of them in the
        frequency domain, based on the change they underwent during the current training step.

        Returns
        -------
        tuple
            (a, b). a: the selected wavelet coefficients, b: their indices.

        """
        logging.info("Returning wavelet compressed model weights")
        data = self.pre_share_model_transformed
        if self.change_based_selection:
            diff = self.model.model_change
            _, index = torch.topk(
                diff.abs(),
                round(self.alpha * len(diff)),
                dim=0,
                sorted=False,
            )
        else:
            _, index = torch.topk(
                data.abs(),
                round(self.alpha * len(data)),
                dim=0,
                sorted=False,
            )

        index, _ = torch.sort(index)
        return data[index], index
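The selection above is a plain magnitude top-k over either the change vector or the coefficients themselves. A minimal illustration with a hypothetical 8-coefficient change vector and alpha = 0.25 (values invented for the example):

import torch

diff = torch.tensor([0.05, -0.40, 0.10, 0.90, -0.02, 0.30, -0.70, 0.01])
alpha = 0.25

# Same rule as apply_wavelet: keep the round(alpha * n) largest |changes|.
k = round(alpha * len(diff))                 # 2 of 8 coefficients
_, index = torch.topk(diff.abs(), k, dim=0, sorted=False)
index, _ = torch.sort(index)
print(index.tolist())                        # [3, 6] -> the 0.90 and -0.70 entries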
    def extract_top_gradients(self):
        """
        Extract the indices and values of the topK gradients.
        The gradients must have been accumulated.

        Returns
        -------
        tuple
            (a, b). a: the magnitudes of the topK gradients, b: their indices.

        """
        if self.lower_bound == 0.0:
            return self.apply_wavelet()

        data = self.pre_share_model_transformed
        if self.change_based_selection:
            diff = self.model.model_change
            _, index = torch.topk(
                diff.abs(),
                round(self.alpha * len(diff)),
                dim=0,
                sorted=False,
            )
        else:
            _, index = torch.topk(
                data.abs(),
                round(self.alpha * len(data)),
                dim=0,
                sorted=False,
            )

        ind, _ = torch.sort(index)
        if self.communication_round > self.start_lower_bounding_at:
            # because the superclass increases it where it is inconvenient for this subclass
            currently_shared = self.model.shared_parameters_counter.clone().detach()
            currently_shared[ind] += 1
            ind_small = (
                currently_shared < self.communication_round * self.lower_bound
            ).nonzero(as_tuple=True)[0]
            ind_small_unique = np.setdiff1d(
                ind_small.numpy(), ind.numpy(), assume_unique=True
            )
            take_max = round(self.lower_bound * self.alpha * data.shape[0])
            logging.info(
                "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max
            )
            if take_max > ind_small_unique.shape[0]:
                take_max = ind_small_unique.shape[0]
            to_take = torch.rand(ind_small_unique.shape[0])
            _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False)
            ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take]
            logging.info("lower bounding: %i %i", len(ind), len(ind_bound))
            # val = torch.concat(val, G_topk[ind_bound])  # not really needed, as these are abs values and not further used
            ind = torch.cat([ind, ind_bound])

        index, _ = torch.sort(ind)
        return _, index
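Beyond the plain top-k, the lower-bound branch randomly re-adds coefficients that have been shared too rarely so far. A minimal standalone sketch of that supplement; the counters, indices, and the larger lower_bound value are hypothetical and chosen only so the supplement is non-empty, they are not the class attributes or defaults.

import numpy as np
import torch

torch.manual_seed(0)

# Hypothetical state after a few rounds: per-coefficient share counters,
# the indices chosen by top-k this round, and the lower-bound settings.
shared_counter = torch.tensor([5, 0, 4, 1, 0, 6, 0, 2])
topk_index = torch.tensor([0, 5])
communication_round, lower_bound, alpha = 10, 0.5, 0.25

# Coefficients shared less often than lower_bound * rounds, excluding the top-k ones.
counter = shared_counter.clone()
counter[topk_index] += 1
under_shared = (counter < communication_round * lower_bound).nonzero(as_tuple=True)[0]
candidates = np.setdiff1d(under_shared.numpy(), topk_index.numpy(), assume_unique=True)

# Randomly pick up to lower_bound * alpha * n of them and append to the selection.
take_max = min(round(lower_bound * alpha * len(counter)), candidates.shape[0])
_, pick = torch.topk(torch.rand(candidates.shape[0]), take_max, dim=0, sorted=False)
extra = torch.from_numpy(candidates)[pick]
final_index, _ = torch.sort(torch.cat([topk_index, extra]))
print(final_index.tolist())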
    def serialized_model(self):
        """
        Convert model to json dict. self.alpha specifies the fraction of model to send.

        Returns
        -------
        dict
            Model converted to json dict

        """
        m = dict()
        if self.alpha >= self.metadata_cap:
            # Share fully
            data = self.pre_share_model_transformed
            m["params"] = data.numpy()
            if self.model.accumulated_changes is not None:
                self.model.accumulated_changes = torch.zeros_like(
                    self.model.accumulated_changes
                )
            return m

        with torch.no_grad():
            topk, indices = self.apply_wavelet()
            self.model.shared_parameters_counter[indices] += 1
            self.model.rewind_accumulation(indices)
            if self.save_shared:
                shared_params = dict()
                shared_params["order"] = list(self.model.state_dict().keys())
                shapes = dict()
                for k, v in self.model.state_dict().items():
                    shapes[k] = list(v.shape)
                shared_params["shapes"] = shapes

                shared_params[self.communication_round] = indices.tolist()  # is slow

                shared_params["alpha"] = self.alpha

                with open(
                    os.path.join(
                        self.folder_path,
                        "{}_shared_params.json".format(self.communication_round + 1),
                    ),
                    "w",
                ) as of:
                    json.dump(shared_params, of)

            if not self.dict_ordered:
                raise NotImplementedError

            m["alpha"] = self.alpha
            m["params"] = topk.numpy()
            m["indices"] = indices.numpy().astype(np.int32)
            m["send_partial"] = True

        return m
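For reference, the partial-share message built above is a plain dict. A hypothetical example of its shape, with invented values for an alpha = 0.25 selection over 8 wavelet coefficients:

import numpy as np

m = {
    "alpha": 0.25,
    "params": np.array([0.90, -0.70], dtype=np.float32),  # selected coefficients
    "indices": np.array([3, 6], dtype=np.int32),           # their positions
    "send_partial": True,
}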
    def deserialized_model_avg(self, m):
        """
        Convert received dict to state_dict.

        Parameters
        ----------
        m : dict
            dict received

        Returns
        -------
        state_dict
            state_dict of received

        """
        if "send_partial" not in m:
            return super().deserialized_model(m)

        with torch.no_grad():
            state_dict = self.model.state_dict()

            if not self.dict_ordered:
                raise NotImplementedError

            # could be made more efficient
            T = torch.zeros_like(self.init_model)
            index_tensor = torch.tensor(m["indices"], dtype=torch.long)
            logging.debug("Original tensor: {}".format(T[index_tensor]))
            T[index_tensor] = torch.tensor(m["params"])
            logging.debug("Final tensor: {}".format(T[index_tensor]))

            return T, index_tensor
    def deserialized_model(self, m):
        """
        Convert received dict to state_dict.

        Parameters
        ----------
        m : dict
            received dict

        Returns
        -------
        state_dict
            state_dict of received

        """
        ret = dict()
        if "send_partial" not in m:
            params = m["params"]
            params_tensor = torch.tensor(params)
            ret["params"] = params_tensor
            return ret

        with torch.no_grad():
            if not self.dict_ordered:
                raise NotImplementedError

            alpha = m["alpha"]
            params_tensor = torch.tensor(m["params"])
            indices_tensor = torch.tensor(m["indices"], dtype=torch.long)
            ret = dict()
            ret["indices"] = indices_tensor
            ret["params"] = params_tensor
            ret["send_partial"] = True
        return ret
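The receiving side later turns such a payload back into a dense wavelet-domain vector by scattering the values at their indices, which is what _averaging does below. A minimal sketch using the hypothetical payload from the previous example:

import numpy as np
import torch

m = {
    "params": np.array([0.90, -0.70], dtype=np.float32),
    "indices": np.array([3, 6], dtype=np.int32),
    "send_partial": True,
}

# Scatter the sparse coefficients into a dense 8-coefficient vector.
dense = torch.zeros(8)
dense[torch.tensor(m["indices"], dtype=torch.long)] = torch.tensor(m["params"])
print(dense.tolist())  # [0.0, 0.0, 0.0, 0.9..., 0.0, 0.0, -0.7..., 0.0]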
    def _averaging(self):
        """
        Averages the received model with the local model

        """
        with torch.no_grad():
            total = None
            weight_total = 0
            wt_params = self.pre_share_model_transformed
            if not self.metro_hastings:
                weight_vector = torch.ones_like(wt_params)
                datas = []
            for i, n in enumerate(self.peer_deques):
                degree, iteration, data = self.peer_deques[n].popleft()
                logging.debug(
                    "Averaging model from neighbor {} of iteration {}".format(
                        n, iteration
                    )
                )
                data = self.deserialized_model(data)
                params = data["params"]
                if "indices" in data:
                    indices = data["indices"]
                    if not self.metro_hastings:
                        weight_vector[indices] += 1
                        topkwf = torch.zeros_like(wt_params)
                        topkwf[indices] = params
                        topkwf = topkwf.reshape(self.wt_shape)
                        datas.append(topkwf)
                    else:
                        # use local data to complement
                        topkwf = wt_params.clone().detach()
                        topkwf[indices] = params
                        topkwf = topkwf.reshape(self.wt_shape)
                else:
                    topkwf = params.reshape(self.wt_shape)
                    if not self.metro_hastings:
                        weight_vector += 1
                        datas.append(topkwf)

                if self.metro_hastings:
                    weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
                    weight_total += weight
                    if total is None:
                        total = weight * topkwf
                    else:
                        total += weight * topkwf

            if not self.metro_hastings:
                weight_vector = 1.0 / weight_vector
                # speed up by exploiting sparsity
                total = wt_params * weight_vector
                for d in datas:
                    total += d * weight_vector
            else:
                # Metro-Hastings
                total += (1 - weight_total) * wt_params

            avg_wf_params = pywt.array_to_coeffs(
                total.numpy(), self.coeff_slices, output_format="wavedec"
            )
            reverse_total = torch.from_numpy(
                pywt.waverec(avg_wf_params, wavelet=self.wavelet)
            )

            start_index = 0
            std_dict = {}
            for i, key in enumerate(self.model.state_dict()):
                end_index = start_index + self.lens[i]
                std_dict[key] = reverse_total[start_index:end_index].reshape(
                    self.shapes[i]
                )
                start_index = end_index

            self.model.load_state_dict(std_dict)
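To make the weighting and the final unpacking concrete, a small standalone sketch of the Metro-Hastings weights and of how the reconstructed flat vector is split back into per-layer tensors; the neighbor count, layer names, shapes, and lengths are hypothetical.

import torch

# Metro-Hastings style weights used above: each of k neighbors gets
# 1 / (max(k, degree) + 1), and the local model keeps the remaining mass.
k, degree = 3, 3                       # hypothetical: 3 neighbors, degree 3
neighbor_weight = 1 / (max(k, degree) + 1)
local_weight = 1 - k * neighbor_weight
print(neighbor_weight, local_weight)   # 0.25 0.25 -> all weights sum to 1

# Splitting a reconstructed flat vector back into per-layer tensors,
# mirroring the loop at the end of _averaging.
reverse_total = torch.arange(12.0)
shapes = [(2, 3), (6,)]
lens = [6, 6]
std_dict, start = {}, 0
for key, shape, n in zip(["weight", "bias"], shapes, lens):
    std_dict[key] = reverse_total[start : start + n].reshape(shape)
    start += n
print({name: t.shape for name, t in std_dict.items()})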