Choco compression fix
Compare changes
Files
2@@ -9,13 +9,10 @@ import torch
@@ -37,6 +34,19 @@ def flatten_state_dict(state_dict):
@@ -69,20 +79,28 @@ def unflatten_state_dict(flat_tensor, reference_state_dict):
@@ -93,7 +111,7 @@ class SecureCompressedAggregation(DPSGDNode):
@@ -106,9 +124,6 @@ class SecureCompressedAggregation(DPSGDNode):
@@ -129,10 +144,47 @@ class SecureCompressedAggregation(DPSGDNode):
@@ -141,32 +193,57 @@ class SecureCompressedAggregation(DPSGDNode):
@@ -178,38 +255,61 @@ class SecureCompressedAggregation(DPSGDNode):
@@ -221,21 +321,53 @@ class SecureCompressedAggregation(DPSGDNode):
@@ -243,16 +375,53 @@ class SecureCompressedAggregation(DPSGDNode):