Skip to content
Snippets Groups Projects
Commit bf5f03d1 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Merge branch 'compression_fixes' into 'main'

Fixes to the previous PR

See merge request sacs/decentralizepy!12
parents 623439ad 37c39c35
No related branches found
No related tags found
No related merge requests found
......@@ -131,25 +131,23 @@ class TCP(Communication):
if self.compress:
if "indices" in data:
data["indices"] = self.compressor.compress(data["indices"])
meta_len = len(
pickle.dumps(data["indices"])
) # ONLY necessary for the statistics
if "params" in data:
data["params"] = self.compressor.compress_float(data["params"])
assert "params" in data
data["params"] = self.compressor.compress_float(data["params"])
data_len = len(pickle.dumps(data["params"]))
output = pickle.dumps(data)
# the compressed meta data gets only a few bytes smaller after pickling
self.total_meta += meta_len
self.total_data += len(output) - meta_len
self.total_meta += len(output) - data_len
self.total_data += data_len
else:
output = pickle.dumps(data)
# centralized testing uses its own instance
if type(data) == dict:
if "indices" in data:
meta_len = len(pickle.dumps(data["indices"]))
else:
meta_len = 0
self.total_meta += meta_len
self.total_data += len(output) - meta_len
assert "params" in data
data_len = len(pickle.dumps(data["params"]))
self.total_meta += len(output) - data_len
self.total_data += data_len
return output
def decrypt(self, sender, data):
......
......@@ -49,4 +49,4 @@ class EliasFpzip(Elias):
decompressed data as array
"""
return fpzip.decompress(bytes, order="C")
return fpzip.decompress(bytes, order="C").squeeze()
......@@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias):
decompressed data as array
"""
return fpzip.decompress(bytes, order="C")
return fpzip.decompress(bytes, order="C").squeeze()
......@@ -159,7 +159,6 @@ class FFT(PartialModel):
if self.alpha >= self.metadata_cap: # Share fully
data = self.pre_share_model_transformed
m["params"] = data.numpy()
self.total_data += len(self.communication.encrypt(m["params"]))
if self.model.accumulated_changes is not None:
self.model.accumulated_changes = torch.zeros_like(
self.model.accumulated_changes
......@@ -200,11 +199,6 @@ class FFT(PartialModel):
m["indices"] = indices.numpy().astype(np.int32)
m["send_partial"] = True
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["alpha"])
)
return m
def deserialized_model(self, m):
......
......@@ -82,7 +82,6 @@ class PartialModel(Sharing):
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
self.accumulation = accumulation
self.save_accumulated = conditional_value(save_accumulated, "", False)
self.change_transformer = change_transformer
......
......@@ -19,7 +19,7 @@ class RandomAlpha(Wavelet):
model,
dataset,
log_dir,
alpha_list=[0.1, 0.2, 0.3, 0.4, 1.0],
alpha_list="[0.1, 0.2, 0.3, 0.4, 1.0]",
dict_ordered=True,
save_shared=False,
metadata_cap=1.0,
......
......@@ -46,13 +46,20 @@ class Sharing:
self.dataset = dataset
self.communication_round = 0
self.log_dir = log_dir
self.total_data = 0
self.peer_deques = dict()
self.my_neighbors = self.graph.neighbors(self.uid)
for n in self.my_neighbors:
self.peer_deques[n] = deque()
self.shapes = []
self.lens = []
with torch.no_grad():
for _, v in self.model.state_dict().items():
self.shapes.append(v.shape)
t = v.flatten().numpy()
self.lens.append(t.shape[0])
def received_from_all(self):
"""
Check if all neighbors have sent the current iteration
......@@ -96,11 +103,15 @@ class Sharing:
Model converted to dict
"""
m = dict()
for key, val in self.model.state_dict().items():
m[key] = val.numpy()
self.total_data += len(self.communication.encrypt(m[key]))
return m
to_cat = []
with torch.no_grad():
for _, v in self.model.state_dict().items():
t = v.flatten()
to_cat.append(t)
flat = torch.cat(to_cat)
data = dict()
data["params"] = flat.numpy()
return data
def deserialized_model(self, m):
"""
......@@ -118,8 +129,13 @@ class Sharing:
"""
state_dict = dict()
for key, value in m.items():
state_dict[key] = torch.from_numpy(value)
T = m["params"]
start_index = 0
for i, key in enumerate(self.model.state_dict()):
end_index = start_index + self.lens[i]
state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i]))
start_index = end_index
return state_dict
def _pre_step(self):
......
......@@ -46,7 +46,6 @@ class Sharing:
self.dataset = dataset
self.communication_round = 0
self.log_dir = log_dir
self.total_data = 0
self.peer_deques = dict()
my_neighbors = self.graph.neighbors(self.uid)
......@@ -101,7 +100,6 @@ class Sharing:
m = dict()
for key, val in self.model.state_dict().items():
m[key] = val.numpy()
self.total_data += len(self.communication.encrypt(m[key]))
return m
def deserialized_model(self, m):
......
......@@ -72,7 +72,6 @@ class SubSampling(Sharing):
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
# self.random_seed_generator = torch.Generator()
# # Will use the random device if supported by CPU, else uses the system time
......@@ -216,12 +215,6 @@ class SubSampling(Sharing):
m["alpha"] = alpha
m["params"] = subsample.numpy()
# logging.info("Converted dictionary to json")
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["seed"])) + len(
self.communication.encrypt(m["alpha"])
)
return m
def deserialized_model(self, m):
......
......@@ -46,7 +46,6 @@ class Synchronous:
self.dataset = dataset
self.communication_round = 0
self.log_dir = log_dir
self.total_data = 0
self.peer_deques = dict()
self.my_neighbors = self.graph.neighbors(self.uid)
......@@ -104,7 +103,6 @@ class Synchronous:
m = dict()
for key, val in self.model.state_dict().items():
m[key] = val - self.init_model[key] # this is -lr*gradient
self.total_data += len(self.communication.encrypt(m))
return m
def serialized_model(self):
......@@ -120,7 +118,6 @@ class Synchronous:
m = dict()
for key, val in self.model.state_dict().items():
m[key] = val.clone().detach()
self.total_data += len(self.communication.encrypt(m))
return m
def deserialized_model(self, m):
......
......@@ -68,7 +68,6 @@ class TopKParams(Sharing):
self.dict_ordered = dict_ordered
self.save_shared = save_shared
self.metadata_cap = metadata_cap
self.total_meta = 0
if self.save_shared:
# Only save for 2 procs: Save space
......@@ -171,10 +170,6 @@ class TopKParams(Sharing):
# m[key] = json.dumps(m[key])
logging.info("Converted dictionary to json")
self.total_data += len(self.communication.encrypt(m["params"]))
self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
self.communication.encrypt(m["offsets"])
)
return m
......
......@@ -181,7 +181,6 @@ class Wavelet(PartialModel):
if self.alpha >= self.metadata_cap: # Share fully
data = self.pre_share_model_transformed
m["params"] = data.numpy()
self.total_data += len(self.communication.encrypt(m["params"]))
if self.model.accumulated_changes is not None:
self.model.accumulated_changes = torch.zeros_like(
self.model.accumulated_changes
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment