Commit f655967c
Authored 2 years ago by Milos Vujasinovic, committed 2 years ago by Rishi Sharma

Add Choco

Parent: 1070c082
Branches containing commit: main
No related tags found. No related merge requests found.

Showing 1 changed file with 448 additions and 0 deletions:
src/decentralizepy/sharing/Choco.py (new file, mode 100644, +448 −0)
import logging

import torch
from collections import OrderedDict

from decentralizepy.sharing.Sharing import Sharing
def zeros_like_state_dict(state_dict):
    """
    Creates a new state dictionary that has the same layers
    (names and sizes) as the input state dictionary, but with
    all values set to zero.

    Parameters
    ----------
    state_dict : dict[str, torch.Tensor]
    """
    result_dict = OrderedDict()
    for tensor_name, tensor_values in state_dict.items():
        result_dict[tensor_name] = torch.zeros_like(tensor_values)
    return result_dict
def get_dict_keys_and_check_matching(dict_1, dict_2):
    """
    Checks if the keys of the two dictionaries match and
    returns them if they do, otherwise raises ValueError.

    Parameters
    ----------
    dict_1 : dict
    dict_2 : dict

    Raises
    ------
    ValueError
        If the keys of the dictionaries don't match
    """
    keys = dict_1.keys()
    if set(keys).difference(set(dict_2.keys())):
        raise ValueError("Dictionaries must have matching keys")
    return keys
def subtract_state_dicts(_1, _2):
    """
    Subtracts one state dictionary from another.

    Parameters
    ----------
    _1 : dict[str, torch.Tensor]
        Minuend state dictionary
    _2 : dict[str, torch.Tensor]
        Subtrahend state dictionary

    Raises
    ------
    ValueError
        If the keys of the state dictionaries don't match
    """
    keys = get_dict_keys_and_check_matching(_1, _2)
    result_dict = OrderedDict()
    for key in keys:
        # Size checking is done by torch during the subtraction
        result_dict[key] = _1[key] - _2[key]
    return result_dict
def self_add_state_dict(_1, _2, constant=1.0):
    """
    Scales one state dictionary by a constant and adds it
    directly to another, minimizing the copies created.
    Equivalent to the operation `_1 += constant * _2`.

    Parameters
    ----------
    _1 : dict[str, torch.Tensor]
        State dictionary
    _2 : dict[str, torch.Tensor]
        State dictionary
    constant : float
        Constant to scale _2 with

    Raises
    ------
    ValueError
        If the keys of the state dictionaries don't match
    """
    keys = get_dict_keys_and_check_matching(_1, _2)
    for key in keys:
        # Size checking is done by torch during the addition
        _1[key] += constant * _2[key]
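For illustration, a minimal standalone sketch of how these two helpers compose, with arbitrary tensor values (not part of the module):

import torch
from collections import OrderedDict

a = OrderedDict(w=torch.tensor([1.0, 2.0]), b=torch.tensor([3.0]))
b = OrderedDict(w=torch.tensor([0.5, 0.5]), b=torch.tensor([1.0]))

diff = subtract_state_dicts(a, b)        # {'w': [0.5, 1.5], 'b': [2.0]}
self_add_state_dict(a, b, constant=2.0)  # in place: a['w'] == [2.0, 3.0], a['b'] == [5.0]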
def flatten_state_dict(state_dict):
    """
    Transforms a state dictionary into a flat tensor by
    flattening and concatenating its tensors.

    Note: changes made to the result won't affect the state dictionary.

    Parameters
    ----------
    state_dict : OrderedDict[str, torch.Tensor]
        A state dictionary to flatten
    """
    return torch.cat([tensor.flatten() for tensor in state_dict.values()], axis=0)
def unflatten_state_dict(flat_tensor, reference_state_dict):
    """
    Transforms a flat tensor into a state dictionary, using
    another state dictionary as a reference for the names and
    sizes of the tensors. Assumes that the flat tensor has the
    same number of elements as the reference state dictionary.

    This operation is the inverse of flatten_state_dict.

    Note: changes made to the result will affect the flat tensor.

    Parameters
    ----------
    flat_tensor : torch.Tensor
        A 1-dim tensor
    reference_state_dict : OrderedDict[str, torch.Tensor]
        A state dictionary used as a reference for the tensor names
        and shapes of the result
    """
    result = OrderedDict()
    start_index = 0
    for tensor_name, tensor in reference_state_dict.items():
        end_index = start_index + tensor.numel()
        result[tensor_name] = flat_tensor[start_index:end_index].reshape(tensor.shape)
        start_index = end_index
    return result
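A minimal round-trip sketch of the flatten/unflatten pair (illustrative shapes, not part of the module):

import torch
from collections import OrderedDict

sd = OrderedDict(w=torch.arange(6.0).reshape(2, 3), b=torch.zeros(3))
flat = flatten_state_dict(sd)            # 1-dim tensor with 9 elements
restored = unflatten_state_dict(flat, sd)
assert all(torch.equal(restored[k], sd[k]) for k in sd)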
def serialize_sparse_tensor(tensor):
    """
    Serializes a sparse tensor by flattening it and returning
    its non-zero values and their indices.

    Parameters
    ----------
    tensor : torch.Tensor
    """
    flat = tensor.flatten()
    indices = flat.nonzero(as_tuple=True)[0]
    values = flat[indices]
    return values, indices
def deserialize_sparse_tensor(values, indices, shape):
    """
    Deserializes a tensor from its non-zero values and indices
    in flattened form and the original shape of the tensor.

    Parameters
    ----------
    values : torch.Tensor
        Non-zero entries of the flattened original tensor
    indices : torch.Tensor
        Respective indices of the non-zero entries of the flattened
        original tensor
    shape : torch.Size or tuple[*int]
        Shape of the original tensor
    """
    result = torch.zeros(size=shape)
    if len(indices):
        flat_result = result.flatten()
        flat_result[indices] = values
    return result
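A minimal round-trip sketch of the sparse (de)serialization pair (illustrative values, not part of the module):

import torch

t = torch.tensor([[0.0, 1.5], [0.0, -2.0]])
values, indices = serialize_sparse_tensor(t)   # values=[1.5, -2.0], indices=[1, 3]
restored = deserialize_sparse_tensor(values, indices, t.shape)
assert torch.equal(restored, t)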
def topk_sparsification_tensor(tensor, alpha):
    """
    Performs top-k sparsification of a tensor and returns the
    input tensor, transformed.

    Note: no copies are created; the input tensor is changed in place.

    Parameters
    ----------
    tensor : torch.Tensor
        A tensor to perform the sparsification on
    alpha : float
        Fraction of top-k values to keep
    """
    tensor_abs = tensor.abs()
    flat_abs_tensor = tensor_abs.flatten()
    numel_to_keep = round(alpha * flat_abs_tensor.numel())
    if numel_to_keep > 0:
        # kthvalue on the negated magnitudes yields (minus) the k-th largest
        # magnitude; every entry with a smaller magnitude is zeroed out
        cutoff_value, _ = torch.kthvalue(-flat_abs_tensor, numel_to_keep)
        tensor[tensor_abs < -cutoff_value] = 0
    return tensor
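A minimal sketch of the in-place top-k behavior (illustrative values, not part of the module):

import torch

t = torch.tensor([0.1, -3.0, 0.2, 5.0, -0.3])
topk_sparsification_tensor(t, alpha=0.4)  # keeps round(0.4 * 5) = 2 entries
# t is modified in place: tensor([0., -3., 0., 5., 0.])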
def topk_sparsification(state_dict, alpha):
    """
    Performs top-k sparsification of a state dict, applied over
    all of its elements together.

    Note: changes made to the result won't affect the input
    state dictionary.

    Parameters
    ----------
    state_dict : OrderedDict[str, torch.Tensor]
        A state dictionary to perform the sparsification on
    alpha : float
        Fraction of top-k values to keep
    """
    flat_tensor = flatten_state_dict(state_dict)
    return unflatten_state_dict(
        topk_sparsification_tensor(flat_tensor, alpha), state_dict
    )
def serialize_sparse_state_dict(state_dict):
    with torch.no_grad():
        # Same concatenation as flatten_state_dict
        concatted_tensors = torch.cat(
            [tensor.flatten() for tensor in state_dict.values()], axis=0
        )
        return serialize_sparse_tensor(concatted_tensors)
def deserialize_sparse_state_dict(values, indices, reference_state_dict):
    with torch.no_grad():
        keys = []
        lens = []
        shapes = []
        for k, v in reference_state_dict.items():
            keys.append(k)
            shapes.append(v.shape)
            lens.append(v.numel())
        total_num_el = sum(lens)
        T = deserialize_sparse_tensor(values, indices, (total_num_el,))
        result_state_dict = OrderedDict()
        start_index = 0
        for i, k in enumerate(keys):
            end_index = start_index + lens[i]
            result_state_dict[k] = T[start_index:end_index].reshape(shapes[i])
            start_index = end_index
        return result_state_dict
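A minimal round-trip sketch at the state-dict level (illustrative values, not part of the module):

import torch
from collections import OrderedDict

sd = OrderedDict(
    w=torch.tensor([[0.0, 2.0], [0.0, 0.0]]), b=torch.tensor([0.0, -1.0])
)
values, indices = serialize_sparse_state_dict(sd)   # values=[2., -1.], indices=[1, 5]
restored = deserialize_sparse_state_dict(values, indices, sd)
assert all(torch.equal(restored[k], sd[k]) for k in sd)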
class Choco(Sharing):
    """
    API defining whom to share with, what to share, and what to do
    on receiving.
    """
    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        step_size,
        alpha,
        compress=False,
        compression_package=None,
        compression_class=None,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph representing neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine)
        step_size : float
            Step size from the formulation of Choco
        alpha : float
            Fraction of elements to keep during top-k sparsification
        """
        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            compress=False,
            compression_package=None,
            compression_class=None,
        )
        self.step_size = step_size
        self.alpha = alpha
        logging.info(
            "type(step_size): %s, value: %s",
            str(type(self.step_size)),
            str(self.step_size),
        )
        logging.info(
            "type(alpha): %s, value: %s",
            str(type(self.alpha)),
            str(self.alpha),
        )
        model_state_dict = model.state_dict()
        self.model_hat = zeros_like_state_dict(model_state_dict)
        self.s = zeros_like_state_dict(model_state_dict)
        self.my_q = None
    def compress_data(self, data):
        result = dict(data)
        if self.compress:
            if "indices" in result:
                result["indices"] = self.compressor.compress(result["indices"])
            if "params" in result:
                result["params"] = self.compressor.compress_float(result["params"])
        return result
    def decompress_data(self, data):
        if self.compress:
            if "indices" in data:
                data["indices"] = self.compressor.decompress(data["indices"])
            if "params" in data:
                data["params"] = self.compressor.decompress_float(data["params"])
        return data
    def _compress(self, x):
        return topk_sparsification(x, self.alpha)
    def _pre_step(self):
        """
        Called at the beginning of step.
        """
        with torch.no_grad():
            # q = compress(x - x_hat)
            self.my_q = self._compress(
                subtract_state_dicts(self.model.state_dict(), self.model_hat)
            )
    def serialized_model(self):
        """
        Converts self.my_q to a dictionary. Here we can choose how
        much to share.

        Returns
        -------
        dict
            Model converted to a dict
        """
        values, indices = serialize_sparse_state_dict(self.my_q)
        data = dict()
        data["params"] = values.numpy()
        data["indices"] = indices.numpy()
        data["send_partial"] = True
        return self.compress_data(data)
    def deserialized_model(self, m):
        """
        Converts a received dict to a state_dict.

        Parameters
        ----------
        m : dict
            Received dict

        Returns
        -------
        state_dict
            state_dict of the received model
        """
        if "send_partial" not in m:
            return super().deserialized_model(m)

        with torch.no_grad():
            m = self.decompress_data(m)
            indices = torch.tensor(m["indices"], dtype=torch.long)
            values = torch.tensor(m["params"])
            return deserialize_sparse_state_dict(
                values, indices, self.model.state_dict()
            )
    def _averaging(self, peer_deques):
        """
        Averages the received model with the local model.
        """
        with torch.no_grad():
            self_add_state_dict(self.model_hat, self.my_q)  # x_hat = q_self + x_hat
            weight_total = 0
            for i, n in enumerate(peer_deques):
                data = peer_deques[n].popleft()
                degree, iteration = data["degree"], data["iteration"]
                del data["degree"]
                del data["iteration"]
                del data["CHANNEL"]
                logging.debug(
                    "Averaging model from neighbor {} of iteration {}".format(
                        n, iteration
                    )
                )
                data = self.deserialized_model(data)
                # Metropolis-Hastings weight
                weight = 1 / (max(len(peer_deques), degree) + 1)
                weight_total += weight
                for key, value in data.items():
                    if key in self.s:
                        self.s[key] += value * weight
                    # else:
                    #     self.s[key] = value * weight
            for key, value in self.my_q.items():
                self.s[key] += (1 - weight_total) * value  # Metropolis-Hastings

            total = self.model.state_dict().copy()
            self_add_state_dict(
                total,
                subtract_state_dicts(self.s, self.model_hat),
                constant=self.step_size,
            )  # x = x + gamma * (s - x_hat)

            self.model.load_state_dict(total)

        self._post_step()
        self.communication_round += 1
    def _averaging_server(self, peer_deques):
        """
        Averages the received models of all working nodes.
        """
        raise NotImplementedError()
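Taken together, one Choco round at a single node roughly follows the sequence below. This is a hedged, standalone sketch using only the module-level helpers above; the real flow runs through the Sharing step and communication machinery, the hyperparameter values are illustrative, and neighbor contributions are omitted (so the self weight is 1):

import torch
from collections import OrderedDict

# Local state of one node: parameters x, public copy x_hat, mixing buffer s
x = OrderedDict(w=torch.randn(4))
x_hat = zeros_like_state_dict(x)
s = zeros_like_state_dict(x)
alpha, step_size = 0.5, 0.3  # illustrative hyperparameters

with torch.no_grad():
    # _pre_step: q = compress(x - x_hat)
    my_q = topk_sparsification(subtract_state_dicts(x, x_hat), alpha)
    # serialized_model: only the sparse (values, indices) pair leaves the node
    values, indices = serialize_sparse_state_dict(my_q)
    # _averaging: x_hat += q; s accumulates the weighted q's of self and neighbors
    self_add_state_dict(x_hat, my_q)
    self_add_state_dict(s, my_q)  # weight 1 in this no-neighbor sketch
    # x = x + gamma * (s - x_hat)
    self_add_state_dict(x, subtract_state_dicts(s, x_hat), constant=step_size)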