Commit 948e733f authored by Milos Vujasinovic's avatar Milos Vujasinovic

Bug fixes

parent 65f1b43a
Branches choco-compression-fix
Pipeline #142353 failed with stages in 0 seconds
@@ -9,13 +9,10 @@ import torch
from matplotlib import pyplot as plt
from collections import OrderedDict
import numpy as np
import contextlib
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
from decentralizepy.node.Node import Node
from decentralizepy.random import RandomState, temp_seed
from decentralizepy.node.DPSGDNode import DPSGDNode
def flatten_state_dict(state_dict):
@@ -37,6 +34,19 @@ def flatten_state_dict(state_dict):
for tensor in state_dict.values()
], axis=0)
def get_number_of_elements(state_dict):
"""
Returns the number of parameters in the state dictionary
of a model.
Parameters
----------
state_dict : OrderedDict[str, torch.tensor]
The state dictionary of the model
"""
return sum([v.numel() for v in state_dict.values()])
def unflatten_state_dict(flat_tensor, reference_state_dict):
"""
Transforms a flat tensor into a state dictionary
@@ -69,20 +79,28 @@ def unflatten_state_dict(flat_tensor, reference_state_dict):
return result
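# Illustrative sketch (not part of the original commit): a round-trip check for
# the flatten/unflatten helpers above, assuming they behave as documented. The
# toy state dict `_toy_sd` is hypothetical.
_toy_sd = OrderedDict(w=torch.ones(2, 3), b=torch.zeros(3))
_flat = flatten_state_dict(_toy_sd)                  # 1-D tensor with 9 elements
_rebuilt = unflatten_state_dict(_flat, _toy_sd)      # same shapes as _toy_sd
assert all(torch.equal(_toy_sd[k], _rebuilt[k]) for k in _toy_sd)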
def top_k(state_dict, alpha):
flat_sd = flatten_state_dict(state_dict)
num_el_to_keep = int(flat_sd.numel() * alpha)
parameters, indices = torch.topk(flat_sd, num_el_to_keep, largest=True)
return parameters, indices
@contextlib.contextmanager
def temp_seed(seed):
state = np.random.get_state()
np.random.seed(seed)
try:
yield
finally:
np.random.set_state(state)
flat_sd = flatten_state_dict(state_dict)
num_el_to_keep = int(flat_sd.numel() * alpha)
_, indices = torch.topk(flat_sd.abs(), num_el_to_keep, largest=True)
return flat_sd, indices
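# Illustrative sketch (not part of the original commit): with alpha = 0.5 the
# new top_k above keeps the positions of the largest-magnitude entries and
# returns the full flat tensor alongside them.
_toy = torch.tensor([0.1, -3.0, 0.5, 2.0])
_, _kept = torch.topk(_toy.abs(), int(_toy.numel() * 0.5), largest=True)
# _kept -> tensor([1, 3]): positions of -3.0 and 2.0 in the flat tensor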
def layerwise_topk(state_dict, alpha):
indice_list, params_list = [], []
numel_so_far = 0
for _, v in state_dict.items():
flat_tensor = v.flatten()
num_el_to_keep = int(flat_tensor.numel() * alpha)
_, indices = torch.topk(flat_tensor, num_el_to_keep, largest=True, sorted=True)
indices, _ = torch.sort(indices)
# print(indices)
indices += numel_so_far
indice_list.append(indices)
params_list.append(flat_tensor)
numel_so_far += flat_tensor.numel()
selected_indices = torch.cat(indice_list)
flat_params = torch.cat(params_list)
return flat_params, selected_indices
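# Illustrative sketch (not part of the original commit): layerwise_topk shifts
# each layer's local top-k indices by the number of elements already seen, so
# the returned indices address the concatenated flat model. `_sd` is a toy dict.
_sd = OrderedDict(a=torch.tensor([1.0, 5.0]), b=torch.tensor([4.0, 2.0]))
_params, _idx = layerwise_topk(_sd, 0.5)
# layer "a" keeps local index 1; layer "b" keeps local index 0, shifted to 2
# _idx -> tensor([1, 2]); _params -> tensor([1., 5., 4., 2.])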
class SecureCompressedAggregation(DPSGDNode):
"""
@@ -93,7 +111,7 @@ class SecureCompressedAggregation(DPSGDNode):
def get_neighbors(self, node=None):
if node is None:
node = self.uid
return self.graph.neighbors(node)
return self.graph.neighbors(node)
def get_distance2_neighbors(self, start_node=None):
"""
@@ -106,9 +124,6 @@ class SecureCompressedAggregation(DPSGDNode):
nodes.remove(start_node)
return nodes
def receive_DPSGD(self):
return self.receive_channel("DPSGD")
def connect_to_nodes(self, set_of_nodes):
"""
Connects to all nodes in the given set. Sends HELLO. Waits for HELLO.
@@ -129,10 +144,47 @@ class SecureCompressedAggregation(DPSGDNode):
for node in wait_acknowledgements:
self.wait_for_hello(node)
def aggregate_models(self, parameters, indices):
def _pseudo_pre_step(self):
pre_share_model = flatten_state_dict(self.model.state_dict()).clone()
change = pre_share_model - self.init_model
self.model.accumulated_changes += change
change = self.model.accumulated_changes.clone().detach()
self.model.model_change = change
def _pseudo_post_step(self):
post_share_model = flatten_state_dict(self.model.state_dict()).clone()
self.init_model = post_share_model
self.model.accumulated_changes += self.init_model - self.prev
self.prev = self.init_model
self.model.model_change = None
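# Illustrative sketch (not part of the original commit), assuming the pseudo
# steps are meant to track the model's drift between sharing rounds: the
# pre-step adds (current flat model - init_model) to accumulated_changes, and
# its magnitude is what top_k_changed ranks below.
_init = torch.tensor([0.0, 0.0, 0.0])
_current = torch.tensor([0.3, -1.2, 0.05])           # flat model after local training
_accumulated = torch.zeros(3)
_accumulated += _current - _init                      # as in _pseudo_pre_step
# _accumulated.abs() -> tensor([0.3000, 1.2000, 0.0500]); coordinate 1 moved most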
def top_k_changed(self, state_dict, alpha):
flat_sd = flatten_state_dict(state_dict)
flat_changes = torch.abs(self.model.model_change)
num_el_to_keep = int(flat_sd.numel() * alpha)
_, indices = torch.topk(flat_changes, num_el_to_keep, largest=True)
return flat_sd, indices
def random_subsampling(self, state_dict, alpha):
flat_sd = flatten_state_dict(state_dict)
logging.info("Subsampling mask seed: %d", torch.seed())
keep_mask = torch.rand(flat_sd.shape) < alpha
indices = keep_mask.nonzero(as_tuple=True)[0]
return flat_sd, indices
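# Illustrative sketch (not part of the original commit): random_subsampling
# keeps roughly an alpha fraction of coordinates by thresholding uniform noise.
_keep = torch.rand(8) < 0.25                          # each position kept with prob ~0.25
_kept_positions = _keep.nonzero(as_tuple=True)[0]     # indices of the kept coordinates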
def aggregate_models(self, parameters, indices, iteration):
# return None
distance2_nodes = self.get_distance2_neighbors()
logging.info("Neighbors: {}".format(self.get_neighbors()))
logging.info("Distance 2 nodes: {}".format(distance2_nodes))
self.connect_to_nodes(distance2_nodes)
compressed_indices = self.sharing.compressor.compress(indices.numpy())
# Generating and sending pairwise masks
sent_masks = {}
for node in distance2_nodes:
@@ -141,32 +193,57 @@ class SecureCompressedAggregation(DPSGDNode):
self.communication.send(node, {
"seed": mask_seed,
"indices": compressed_indices,
"iteration": iteration,
"CHANNEL": "PRE-SECURE-AGG-STEP"
})
logging.info("Sent mask to %d", node)
# Receiving pairwise masks and indices
received_data = {}
waiting_mask_from = distance2_nodes.copy()
# Processing masks received before the given round
for sender, mask_data in self.masks_received_early:
if mask_data["iteration"] != iteration:
raise ValueError("Mask iterations don't match")
del mask_data["iteration"]
received_data[sender] = mask_data
received_data[sender]["indices"] = torch.tensor(
self.sharing.compressor.decompress(received_data[sender]["indices"]), dtype=torch.long)
waiting_mask_from.remove(sender)
self.masks_received_early = []
# Waiting for other masks
while waiting_mask_from:
sender, data = self.receive_channel("PRE-SECURE-AGG-STEP")
del data["CHANNEL"]
if sender in waiting_mask_from:
# print('Seed from', sender, 'is', data["seed"])
del data["CHANNEL"]
if data["iteration"] != iteration:
raise ValueError("Mask iterations don't match")
del data["iteration"]
received_data[sender] = data
received_data[sender]["indices"] = torch.tensor(
self.sharing.compressor.decompress(received_data[sender]["indices"]), dtype=torch.long)
waiting_mask_from.remove(sender)
else:
self.masks_received_early.append((sender, data))
# Building masks
pairwise_mask_difference = {}
indices_size = indices.size()[0]
logging.info("Indices intended to share: %s", indices_size)
for node, data in received_data.items():
# sortednp.intersect supports intersection of sorted arrays (make sure to cast tensors to numpy arrays)
# torch.topk doesn't return indices sorted...
_, my_indices_pos, _ = np.intersect1d(indices, data["indices"], return_indices=True)
mask_shape = my_indices_pos.shape
# logging.info("My indices %d, neighbors indices %d, intersect %d", indices.size()[0], data["indices"].size()[0], my_indices_pos.size)
logging.info("Indice intersects: %s", my_indices_pos.size)
pairwise_mask_difference[node] = {
"value": (self.generate_mask(sent_masks[node], mask_shape) - self.generate_mask(data["seed"], mask_shape)).double(),
"indices": my_indices_pos
}
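# Illustrative sketch (not part of the original commit): np.intersect1d with
# return_indices=True gives, for every common value, its position in the first
# array; those positions are what index the local flat model further below.
_mine = np.array([2, 5, 7, 9])
_theirs = np.array([5, 9, 11])
_common, _pos_in_mine, _pos_in_theirs = np.intersect1d(_mine, _theirs, return_indices=True)
# _common -> [5 9], _pos_in_mine -> [1 3], _pos_in_theirs -> [0 1]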
# Sending models to neighbors
# print(indices)
self.my_neighbors = self.get_neighbors()
self.connect_to_nodes(self.my_neighbors)
for neighbor in self.my_neighbors:
@@ -178,38 +255,61 @@ class SecureCompressedAggregation(DPSGDNode):
if self.uid == pairing_node:
continue
pair_mask = pairwise_mask_difference[pairing_node]["value"]
pair_indices = pairwise_mask_difference[pairing_node]["indices"]
pair_indices_pos = pairwise_mask_difference[pairing_node]["indices"]
pair_indices = indices[pair_indices_pos]
# print(perturbated_model.shape, pair_indices.shape, pair_mask.shape)
# print(perturbated_model.dtype, pair_mask.dtype)
perturbated_model[pair_indices] += pair_mask
masking_count[pair_indices] += 1
masking_count[pair_indices_pos] += 1
non_zero_indices = masking_count.nonzero(as_tuple=True)[0]
indices_to_send = indices[non_zero_indices]
parameters_to_send = parameters[non_zero_indices]
parameters_to_send = parameters[indices_to_send]
# Debug to 'skip' protocol (delete later)
# parameters_to_send = parameters[indices]
# indices_to_send = indices
logging.info('Sending indices: %d', indices_to_send.shape[0])
compressed_parameters = self.sharing.compressor.compress_float(parameters_to_send.numpy())
compressed_indices = self.sharing.compressor.compress(indices_to_send.numpy())
self.communication.send(neighbor, {
"params": compressed_parameters,
"indices": compressed_indices,
"iteration": iteration,
"CHANNEL": "SECURE_MODEL_CHANNEL"
})
logging.info("Sent model to %d", neighbor)
# Receiving models from neighbors
received_models = {}
waiting_models_from = self.my_neighbors.copy()
for sender, model_data in self.models_received_early:
if model_data["iteration"] != iteration:
raise ValueError("Model iterations don't match")
del model_data["iteration"]
received_models[sender] = model_data
received_models[sender]["indices"] = torch.tensor(
self.sharing.compressor.decompress(received_models[sender]["indices"]), dtype=torch.long)
received_models[sender]["params"] = torch.tensor(
self.sharing.compressor.decompress_float(received_models[sender]["params"]))
waiting_models_from.remove(sender)
self.models_received_early = []
while waiting_models_from:
# print(self.uid, "Waiting models from:", waiting_models_from)
sender, data = self.receive_channel("SECURE_MODEL_CHANNEL")
del data["CHANNEL"]
if sender in waiting_models_from:
# print('Seed from', sender, 'is', data["seed"])
del data["CHANNEL"]
if data["iteration"] != iteration:
raise ValueError("Model iterations don't match")
del data["iteration"]
received_models[sender] = data
received_models[sender]["indices"] = torch.tensor(
self.sharing.compressor.decompress(received_models[sender]["indices"]), dtype=torch.long)
received_models[sender]["params"] = torch.tensor(
self.sharing.compressor.decompress_float(received_models[sender]["params"]))
waiting_models_from.remove(sender)
else:
self.models_received_early.append((sender, data))
# Averaging
weight = 1 / (len(self.my_neighbors) + 1)
preshare_model = flatten_state_dict(self.model.state_dict())
@@ -221,21 +321,53 @@ class SecureCompressedAggregation(DPSGDNode):
recovered_model[indices] = params
new_flat_model += weight * recovered_model
# Loading the new state dictionary
logging.info('L0=' + str((parameters-new_flat_model).abs().sum()))
logging.info('model_L0=' + str((parameters).abs().sum()))
new_state_dict = unflatten_state_dict(new_flat_model, self.model.state_dict())
self.model.load_state_dict(new_state_dict)
def generate_mask(self, seed, size):
with temp_seed(seed):
# Figure out best distribution to add
return torch.Tensor(np.random.normal(0, 100000, size=size))
return torch.Tensor(np.random.uniform(-10000000, 20000000, size=size))
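# Illustrative sketch (not part of the original commit), assuming temp_seed from
# decentralizepy.random seeds numpy like the helper it replaced: the mask is a
# pure function of the seed, so both endpoints can rebuild the same mask from
# the exchanged seed alone.
with temp_seed(42):
    _m1 = torch.Tensor(np.random.uniform(-10000000, 20000000, size=(4,)))
with temp_seed(42):
    _m2 = torch.Tensor(np.random.uniform(-10000000, 20000000, size=(4,)))
assert torch.equal(_m1, _m2)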
def extract_top_gradients(self):
"""
Extract the indices and values of the topK gradients.
The gradients must have been accumulated.
Returns
-------
tuple
(a,b). a: The magnitudes of the topK gradients, b: Their indices.
"""
logging.info("Returning topk gradients")
G_topk = torch.abs(self.model.model_change)
std, mean = torch.std_mean(G_topk, unbiased=False)
self.std = std.item()
self.mean = mean.item()
_, index = torch.topk(
G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=True
)
index, _ = torch.sort(index)
return _, index
def run(self):
"""
Start the decentralized learning
"""
torch.manual_seed(self.uid)
np.random.seed(self.uid)
# logging.info("Start, Np num: %f, torch num: %f", np.random.random(), torch.rand((1,))[0])
with torch.no_grad():
self.init_model = flatten_state_dict(self.model.state_dict())
self.model.accumulated_changes = torch.zeros_like(
self.init_model)
self.prev = self.init_model
self.sec_agg_state = RandomState(self.uid)
self.testset = self.dataset.get_testset()
rounds_to_test = self.test_after
@@ -243,16 +375,53 @@ class SecureCompressedAggregation(DPSGDNode):
global_epoch = 1
change = 1
self.old_model_holder = flatten_state_dict(self.model.state_dict()).clone()
self.model.accumulated_changes = torch.zeros_like(
self.old_model_holder)
self.masks_received_early = []
self.models_received_early = []
# logging.info("Before iter, Np num: %f, torch num: %f", np.random.random(), torch.rand((1,))[0])
logging.info("Number of parameters in model: %d",
get_number_of_elements(self.model.state_dict()))
for iteration in range(self.iterations):
if self.uid == 0:
print("Iteration", iteration)
logging.info("Starting training iteration: %d", iteration)
rounds_to_train_evaluate -= 1
rounds_to_test -= 1
# logging.info("Iteration %d before train, NP state: %d, torch state: %d",
# iteration,
# np.random.get_state()[1].sum(),
# torch.random.get_rng_state().sum())
self.iteration = iteration
self.trainer.train(self.dataset)
self.aggregate_models(*top_k(self.model.state_dict(), 0.3))
# logging.info("Iteration %d before share, NP state: %d, torch state: %d",
# iteration,
# np.random.get_state()[1].sum(),
# torch.random.get_rng_state().sum())
self._pseudo_pre_step()
# self.aggregate_models(*top_k(self.model.state_dict(), 0.3), iteration)
# self.aggregate_models(*self.random_subsampling(self.model.state_dict(), 0.3), iteration)
# flat_model, indices_to_share = self.top_k_changed(self.model.state_dict(), 0.3)
flat_model, indices_to_share = self.top_k_changed(self.model.state_dict(), 1)
self.model.shared_parameters_counter[indices_to_share] += 1
self.model.rewind_accumulation(indices_to_share)
with self.sec_agg_state.activate():
self.aggregate_models(flat_model, indices_to_share, iteration)
self._pseudo_post_step()
# logging.info("Iteration %d, NP state: %d, torch state: %d",
# iteration,
# np.random.get_state()[1].sum(),
# torch.random.get_rng_state().sum())
if self.reset_optimizer:
self.optimizer = self.optimizer_class(
self.model.parameters(), **self.optimizer_params
......