SaCS / Semester-Projects / fall22 / secure-aggregation · Commits

Commit 948e733f, authored 2 years ago by Milos Vujasinovic

    Bug fixes

Parent: 65f1b43a
Branch containing commit: choco-compression-fix
No related tags found. No related merge requests found.
Pipeline #142353 failed with stages in 0 seconds.
Showing 1 changed file: src/decentralizepy/node/SecureCompressedAggregation.py (+206 additions, −37 deletions)
```diff
@@ -9,13 +9,10 @@ import torch
 from matplotlib import pyplot as plt
 from collections import OrderedDict
 import numpy as np
-import contextlib
 from decentralizepy import utils
 from decentralizepy.graphs.Graph import Graph
 from decentralizepy.mappings.Mapping import Mapping
 from decentralizepy.node.Node import Node
+from decentralizepy.random import RandomState, temp_seed
 from decentralizepy.node.DPSGDNode import DPSGDNode


 def flatten_state_dict(state_dict):
```
```diff
@@ -37,6 +34,19 @@ def flatten_state_dict(state_dict):
             for tensor in state_dict.values()
         ],
         axis=0,
     )


+def get_number_of_elements(state_dict):
+    """
+    Returns the number of parameters in the state dictionary of a model.
+
+    Parameters
+    ----------
+    state_dict : OrderedDict[str, torch.tensor]
+        The state dictionary of the model
+    """
+    return sum([v.numel() for v in state_dict.values()])
+
+
 def unflatten_state_dict(flat_tensor, reference_state_dict):
     """
     Transforms a flat tensor into a state dictionary
```
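For reference, flatten_state_dict and unflatten_state_dict (whose body is collapsed in this diff) must be mutual inverses. A minimal sketch of that contract; the unflatten body below is a plausible reconstruction, not the repository's code:

```python
# Sketch (not from the repo): the round-trip contract that
# flatten_state_dict / unflatten_state_dict are expected to satisfy.
from collections import OrderedDict

import torch


def flatten_state_dict(state_dict):
    # Concatenate every tensor, in insertion order, into one 1-D tensor.
    return torch.cat([t.flatten() for t in state_dict.values()], dim=0)


def unflatten_state_dict(flat_tensor, reference_state_dict):
    # Assumed behaviour: slice the flat tensor back into tensors
    # shaped like the reference state dict.
    result = OrderedDict()
    offset = 0
    for name, ref in reference_state_dict.items():
        n = ref.numel()
        result[name] = flat_tensor[offset : offset + n].reshape(ref.shape)
        offset += n
    return result


sd = OrderedDict(w=torch.randn(2, 3), b=torch.randn(3))
flat = flatten_state_dict(sd)
assert flat.numel() == sum(v.numel() for v in sd.values())  # == get_number_of_elements(sd)
restored = unflatten_state_dict(flat, sd)
assert all(torch.equal(sd[k], restored[k]) for k in sd)
```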
```diff
@@ -69,20 +79,28 @@ def unflatten_state_dict(flat_tensor, reference_state_dict):
     return result


 def top_k(state_dict, alpha):
-    flat_sd = flatten_state_dict(state_dict)
-    num_el_to_keep = int(flat_sd.numel() * alpha)
-    parameters, indices = torch.topk(flat_sd, num_el_to_keep, largest=True)
-    return parameters, indices
-
-
-@contextlib.contextmanager
-def temp_seed(seed):
-    state = np.random.get_state()
-    np.random.seed(seed)
-    try:
-        yield
-    finally:
-        np.random.set_state(state)
+    flat_sd = flatten_state_dict(state_dict)
+    num_el_to_keep = int(flat_sd.numel() * alpha)
+    _, indices = torch.topk(flat_sd.abs(), num_el_to_keep, largest=True)
+    return flat_sd, indices
+
+
+def layerwise_topk(state_dict, alpha):
+    indice_list, params_list = [], []
+    numel_so_far = 0
+    for _, v in state_dict.items():
+        flat_tensor = v.flatten()
+        num_el_to_keep = int(flat_tensor.numel() * alpha)
+        _, indices = torch.topk(flat_tensor, num_el_to_keep, largest=True, sorted=True)
+        indices, _ = torch.sort(indices)
+        # print(indices)
+        indices += numel_so_far
+        indice_list.append(indices)
+        params_list.append(flat_tensor)
+        numel_so_far += flat_tensor.numel()
+    selected_indices = torch.cat(indice_list)
+    flat_params = torch.cat(params_list)
+    return flat_params, selected_indices


 class SecureCompressedAggregation(DPSGDNode):
     """
```
```diff
@@ -93,7 +111,7 @@ class SecureCompressedAggregation(DPSGDNode):
     def get_neighbors(self, node=None):
-        return self.graph.neighbors(node)
+        if node is None:
+            node = self.uid
+        return self.graph.neighbors(node)

     def get_distance2_neighbors(self, start_node=None):
         """
```
```diff
@@ -106,9 +124,6 @@ class SecureCompressedAggregation(DPSGDNode):
         nodes.remove(start_node)
         return nodes

-    def receive_DPSGD(self):
-        return self.receive_channel("DPSGD")
-
     def connect_to_nodes(self, set_of_nodes):
         """
         Connects all neighbors. Sends HELLO. Waits for HELLO.
```
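The body of get_distance2_neighbors is largely collapsed in this diff. A hypothetical equivalent over a plain adjacency dict (an assumption, not the repository's Graph API) shows the intent: the union of the neighbours' neighbourhoods minus the node itself, which is exactly the set of peers that later exchange pairwise masks:

```python
# Hypothetical stand-in for get_distance2_neighbors: every node
# reachable through some common neighbour, excluding the start node.
def get_distance2_neighbors(graph, start_node):
    nodes = set()
    for neighbor in graph[start_node]:
        nodes.update(graph[neighbor])  # neighbours of my neighbours
    nodes.discard(start_node)          # a node never pairs with itself
    return nodes


# A ring of four nodes: 0 - 1 - 2 - 3 - 0
ring = {0: {1, 3}, 1: {0, 2}, 2: {1, 3}, 3: {0, 2}}
print(get_distance2_neighbors(ring, 0))  # {2}
```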
```diff
@@ -129,10 +144,47 @@ class SecureCompressedAggregation(DPSGDNode):
         for node in wait_acknowledgements:
             self.wait_for_hello(node)

-    def aggregate_models(self, parameters, indices):
+    def _pseudo_pre_step(self):
+        pre_share_model = flatten_state_dict(self.model.state_dict()).clone()
+        change = pre_share_model - self.init_model
+        self.model.accumulated_changes += change
+        change = self.model.accumulated_changes.clone().detach()
+        self.model.model_change = change
+
+    def _pseudo_post_step(self):
+        post_share_model = flatten_state_dict(self.model.state_dict()).clone()
+        self.init_model = post_share_model
+        self.model.accumulated_changes += self.init_model - self.prev
+        self.prev = self.init_model
+        self.model.model_change = None
+
+    def top_k_changed(self, state_dict, alpha):
+        flat_sd = flatten_state_dict(state_dict)
+        flat_changes = torch.abs(self.model.model_change)
+        num_el_to_keep = int(flat_sd.numel() * alpha)
+        _, indices = torch.topk(flat_changes, num_el_to_keep, largest=True)
+        return flat_sd, indices
+
+    def random_subsampling(self, state_dict, alpha):
+        flat_sd = flatten_state_dict(state_dict)
+        logging.info("Subsampling mask seed: %d", torch.seed())
+        keep_mask = torch.rand(flat_sd.shape) < alpha
+        indices = keep_mask.nonzero(as_tuple=True)[0]
+        return flat_sd, indices
+
+    def aggregate_models(self, parameters, indices, iteration):
+        # return None
         distance2_nodes = self.get_distance2_neighbors()
+        logging.info("Neighbors: {}".format(self.get_neighbors()))
+        logging.info("Distance 2 nodes: {}".format(distance2_nodes))
         self.connect_to_nodes(distance2_nodes)
+        compressed_indices = self.sharing.compressor.compress(indices.numpy())

         # Generating and sending pairwise masks
         sent_masks = {}
         for node in distance2_nodes:
```
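_pseudo_pre_step, top_k_changed, and rewind_accumulation together implement error-feedback-style residual accumulation: per-coordinate changes build up across iterations, the coordinates with the largest accumulated change are shared, and only the shared coordinates are reset, so unshared updates are carried forward rather than lost. rewind_accumulation is defined elsewhere in decentralizepy; the reset line below stands in for it as an assumption. A self-contained sketch:

```python
# Sketch of the residual-accumulation bookkeeping, assuming
# rewind_accumulation(indices) zeroes the accumulator at the
# shared coordinates only.
import torch

model = torch.zeros(6)
accumulated = torch.zeros(6)  # self.model.accumulated_changes
prev = model.clone()          # self.prev

for step in range(3):
    model += torch.randn(6) * 0.1        # stand-in for local training
    accumulated += model - prev          # _pseudo_pre_step
    prev = model.clone()

    k = 2                                # share the 2 most-changed coords
    _, indices = torch.topk(accumulated.abs(), k)  # top_k_changed
    accumulated[indices] = 0.0           # rewind_accumulation(indices)

print(accumulated)  # residual changes not yet shared, carried to later rounds
```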
```diff
@@ -141,32 +193,57 @@ class SecureCompressedAggregation(DPSGDNode):
             self.communication.send(
                 node,
                 {
                     "seed": mask_seed,
                     "indices": compressed_indices,
                     "iteration": iteration,
                     "CHANNEL": "PRE-SECURE-AGG-STEP",
                 },
             )
             logging.info("Sent mask to %d", node)

         # Receiving pairwise masks and indices
         received_data = {}
         waiting_mask_from = distance2_nodes.copy()

         # Processing masks received before the given round
         for sender, mask_data in self.masks_received_early:
             if mask_data["iteration"] != iteration:
                 raise ValueError("Mask iterations don't match")
             del mask_data["iteration"]
             received_data[sender] = mask_data
             received_data[sender]["indices"] = torch.tensor(
                 self.sharing.compressor.decompress(received_data[sender]["indices"]),
                 dtype=torch.long,
             )
             waiting_mask_from.remove(sender)
         self.masks_received_early = []

         # Waiting for other masks
         while waiting_mask_from:
             sender, data = self.receive_channel("PRE-SECURE-AGG-STEP")
+            del data["CHANNEL"]
             if sender in waiting_mask_from:
                 # print('Seed from', sender, 'is', data["seed"])
-                del data["CHANNEL"]
                 if data["iteration"] != iteration:
                     raise ValueError("Mask iterations don't match")
                 del data["iteration"]
                 received_data[sender] = data
                 received_data[sender]["indices"] = torch.tensor(
                     self.sharing.compressor.decompress(received_data[sender]["indices"]),
                     dtype=torch.long,
                 )
                 waiting_mask_from.remove(sender)
             else:
                 self.masks_received_early.append((sender, data))

         # Building masks
         pairwise_mask_difference = {}
         indices_size = indices.size()[0]
         logging.info("Indices intended to share: %s", indices_size)
         for node, data in received_data.items():
             # sortednp.intersect supports intersection of sorted array (make sure to cast tensor to nparray)
             # torch.topk doesn't return indices sorted...
             _, my_indices_pos, _ = np.intersect1d(indices, data["indices"], return_indices=True)
             mask_shape = my_indices_pos.shape
             # logging.info("My indices %d, neighbors indices %d, intersect %d", indices.size()[0], data["indices"].size()[0], my_indices_pos.size)
             logging.info("Indice intersects: %s", my_indices_pos.size)
             pairwise_mask_difference[node] = {
                 "value": (
                     self.generate_mask(sent_masks[node], mask_shape)
                     - self.generate_mask(data["seed"], mask_shape)
                 ).double(),
                 "indices": my_indices_pos,
             }

         # Sending models to neighbors
         # print(indices)
         self.my_neighbors = self.get_neighbors()
         self.connect_to_nodes(self.my_neighbors)
         for neighbor in self.my_neighbors:
```
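The seed exchange above is what makes the masking both private and free at aggregation time: for a pair (A, B), A perturbs with mask(seed_A) − mask(seed_B) over the intersected positions while B applies the exact negation, so the pair's contributions cancel in the sum. A minimal sketch, with temp_seed inlined (the repository imports it from decentralizepy.random):

```python
# Demonstration of pairwise-mask cancellation between two nodes.
import contextlib

import numpy as np
import torch


@contextlib.contextmanager
def temp_seed(seed):
    # Reseed NumPy temporarily so both nodes regenerate identical masks.
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)


def generate_mask(seed, size):
    with temp_seed(seed):
        return torch.Tensor(np.random.normal(0, 100000, size=size))


seed_a, seed_b, shape = 7, 42, (4,)
model_a = torch.ones(4, dtype=torch.float64)
model_b = 2 * torch.ones(4, dtype=torch.float64)

perturbed_a = model_a + (generate_mask(seed_a, shape) - generate_mask(seed_b, shape)).double()
perturbed_b = model_b + (generate_mask(seed_b, shape) - generate_mask(seed_a, shape)).double()

# Each perturbed model alone is dominated by the huge mask values,
# but the masks cancel exactly in the pairwise sum.
assert torch.allclose(perturbed_a + perturbed_b, model_a + model_b)
```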
```diff
@@ -178,38 +255,61 @@ class SecureCompressedAggregation(DPSGDNode):
                 if self.uid == pairing_node:
                     continue
                 pair_mask = pairwise_mask_difference[pairing_node]["value"]
-                pair_indices = pairwise_mask_difference[pairing_node]["indices"]
+                pair_indices_pos = pairwise_mask_difference[pairing_node]["indices"]
+                pair_indices = indices[pair_indices_pos]
                 # print(perturbated_model.shape, pair_indices.shape, pair_mask.shape)
                 # print(perturbated_model.dtype, pair_mask.dtype)
                 perturbated_model[pair_indices] += pair_mask
-                masking_count[pair_indices] += 1
+                masking_count[pair_indices_pos] += 1

             non_zero_indices = masking_count.nonzero(as_tuple=True)[0]
             indices_to_send = indices[non_zero_indices]
-            parameters_to_send = parameters[non_zero_indices]
+            parameters_to_send = parameters[indices_to_send]

             # Debug to 'skip' protocol (delete later)
             # parameters_to_send = parameters[indices]
             # indices_to_send = indices

             logging.info('Sending indices: %d', indices_to_send.shape[0])
             compressed_parameters = self.sharing.compressor.compress_float(parameters_to_send.numpy())
             compressed_indices = self.sharing.compressor.compress(indices_to_send.numpy())
             self.communication.send(
                 neighbor,
                 {
                     "params": compressed_parameters,
                     "indices": compressed_indices,
                     "iteration": iteration,
                     "CHANNEL": "SECURE_MODEL_CHANNEL",
                 },
             )
             logging.info("Sent model to %d", neighbor)

         # Receiving models from neighbors
         received_models = {}
         waiting_models_from = self.my_neighbors.copy()
         for sender, model_data in self.models_received_early:
             if model_data["iteration"] != iteration:
                 raise ValueError("Model iterations don't match")
             del model_data["iteration"]
             received_models[sender] = model_data
             received_models[sender]["indices"] = torch.tensor(
                 self.sharing.compressor.decompress(received_models[sender]["indices"]),
                 dtype=torch.long,
             )
             received_models[sender]["params"] = torch.tensor(
                 self.sharing.compressor.decompress_float(received_models[sender]["params"])
             )
             waiting_models_from.remove(sender)
         self.models_received_early = []

         while waiting_models_from:
             # print(self.uid, "Waiting models from:", waiting_models_from)
             sender, data = self.receive_channel("SECURE_MODEL_CHANNEL")
+            del data["CHANNEL"]
             if sender in waiting_models_from:
                 # print('Seed from', sender, 'is', data["seed"])
-                del data["CHANNEL"]
                 if data["iteration"] != iteration:
                     raise ValueError("Model iterations don't match")
                 del data["iteration"]
                 received_models[sender] = data
                 received_models[sender]["indices"] = torch.tensor(
                     self.sharing.compressor.decompress(received_models[sender]["indices"]),
                     dtype=torch.long,
                 )
                 received_models[sender]["params"] = torch.tensor(
                     self.sharing.compressor.decompress_float(received_models[sender]["params"])
                 )
                 waiting_models_from.remove(sender)
             else:
                 self.models_received_early.append((sender, data))

         # Averaging
         weight = 1 / (len(self.my_neighbors) + 1)
         preshare_model = flatten_state_dict(self.model.state_dict())
```
```diff
@@ -221,21 +321,53 @@ class SecureCompressedAggregation(DPSGDNode):
             recovered_model[indices] = params
             new_flat_model += weight * recovered_model

         # Loading new state dictionary
         logging.info('L0=' + str((parameters - new_flat_model).abs().sum()))
         logging.info('model_L0=' + str((parameters).abs().sum()))
         new_state_dict = unflatten_state_dict(new_flat_model, self.model.state_dict())
         self.model.load_state_dict(new_state_dict)

     def generate_mask(self, seed, size):
         with temp_seed(seed):
             # Figure out best distribution to add
-            return torch.Tensor(np.random.normal(0, 100000, size=size))
+            return torch.Tensor(np.random.uniform(-10000000, 20000000, size=size))

     def extract_top_gradients(self):
         """
         Extract the indices and values of the topK gradients.
         The gradients must have been accumulated.

         Returns
         -------
         tuple
             (a,b). a: The magnitudes of the topK gradients, b: Their indices.
         """
         logging.info("Returning topk gradients")
         G_topk = torch.abs(self.model.model_change)
         std, mean = torch.std_mean(G_topk, unbiased=False)
         self.std = std.item()
         self.mean = mean.item()
         _, index = torch.topk(G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=True)
         index, _ = torch.sort(index)
         return _, index

     def run(self):
         """
         Start the decentralized learning
         """
         torch.manual_seed(self.uid)
         np.random.seed(self.uid)
         # logging.info("Start, Np num: %f, torch num: %f", np.random.random(), torch.rand((1,))[0])
         with torch.no_grad():
             self.init_model = flatten_state_dict(self.model.state_dict())
             self.model.accumulated_changes = torch.zeros_like(self.init_model)
             self.prev = self.init_model
         self.sec_agg_state = RandomState(self.uid)

         self.testset = self.dataset.get_testset()
         rounds_to_test = self.test_after
```
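The weighted averaging at the top of this hunk scatters each received sparse model into a dense zero vector before the uniformly weighted sum, so coordinates a neighbour did not send contribute zero for that neighbour. A small numeric sketch (the uniform weight follows the weight = 1 / (len(self.my_neighbors) + 1) line in the previous hunk; the own-model term is an assumption, since that part of the loop is collapsed):

```python
# Sketch of sparse scatter-averaging over received (indices, params) pairs.
import torch

preshare_model = torch.tensor([1.0, 2.0, 3.0, 4.0])  # own flattened model
received = {
    10: {"indices": torch.tensor([0, 2]), "params": torch.tensor([5.0, 7.0])},
    11: {"indices": torch.tensor([1]), "params": torch.tensor([9.0])},
}

weight = 1 / (len(received) + 1)          # uniform weight, self included
new_flat_model = weight * preshare_model  # own contribution
for data in received.values():
    recovered = torch.zeros_like(preshare_model)   # dense zero vector
    recovered[data["indices"]] = data["params"]    # scatter sparse update
    new_flat_model += weight * recovered

print(new_flat_model)  # tensor([2.0000, 3.6667, 3.3333, 1.3333])
```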
```diff
@@ -243,16 +375,53 @@ class SecureCompressedAggregation(DPSGDNode):
         global_epoch = 1
         change = 1

+        self.old_model_holder = flatten_state_dict(self.model.state_dict()).clone()
+        self.model.accumulated_changes = torch.zeros_like(self.old_model_holder)
+        self.masks_received_early = []
+        self.models_received_early = []
+
+        # logging.info("Before iter, Np num: %f, torch num: %f", np.random.random(), torch.rand((1,))[0])
+        logging.info("Number of parameters in model: %d", get_number_of_elements(self.model.state_dict()))
         for iteration in range(self.iterations):
+            if self.uid == 0:
+                print("Iteration", iteration)
             logging.info("Starting training iteration: %d", iteration)
             rounds_to_train_evaluate -= 1
             rounds_to_test -= 1
             # logging.info("Iteration %d before train, NP state: %d, torch state: %d",
             #     iteration,
             #     np.random.get_state()[1].sum(),
             #     torch.random.get_rng_state().sum())
             self.iteration = iteration
             self.trainer.train(self.dataset)

-            self.aggregate_models(*top_k(self.model.state_dict(), 0.3))
             # logging.info("Iteration %d before share, NP state: %d, torch state: %d",
             #     iteration,
             #     np.random.get_state()[1].sum(),
             #     torch.random.get_rng_state().sum())
+            self._pseudo_pre_step()
+            # self.aggregate_models(*top_k(self.model.state_dict(), 0.3), iteration)
+            # self.aggregate_models(*self.random_subsampling(self.model.state_dict(), 0.3), iteration)
+            # flat_model, indices_to_share = self.top_k_changed(self.model.state_dict(), 0.3)
+            flat_model, indices_to_share = self.top_k_changed(self.model.state_dict(), 1)
+            self.model.shared_parameters_counter[indices_to_share] += 1
+            self.model.rewind_accumulation(indices_to_share)
+            with self.sec_agg_state.activate():
+                self.aggregate_models(flat_model, indices_to_share, iteration)
+            self._pseudo_post_step()
+            # logging.info("Iteration %d, NP state: %d, torch state: %d",
+            #     iteration,
+            #     np.random.get_state()[1].sum(),
+            #     torch.random.get_rng_state().sum())
             if self.reset_optimizer:
                 self.optimizer = self.optimizer_class(
                     self.model.parameters(), **self.optimizer_params
```
(The remainder of the diff is collapsed and not shown.)