Commit 3dc5b174 authored by Jeffrey Wigger

sharing works now with data compression

parent 786dcb98
@@ -49,4 +49,4 @@ class EliasFpzip(Elias):
         decompressed data as array

         """
-        return fpzip.decompress(bytes, order="C")
+        return fpzip.decompress(bytes, order="C").squeeze()
@@ -49,4 +49,4 @@ class EliasFpzipLossy(Elias):
         decompressed data as array

         """
-        return fpzip.decompress(bytes, order="C")
+        return fpzip.decompress(bytes, order="C").squeeze()
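Note on the `.squeeze()` change: the `fpzip` Python binding returns decompressed data padded out to a 4-D array, so a flat parameter vector comes back with extra singleton axes; squeezing restores the original 1-D shape. A minimal round-trip sketch, assuming the `fpzip` pip package and an illustrative 1000-element float32 vector (not values from this commit):

import fpzip
import numpy as np

data = np.random.rand(1000).astype(np.float32)          # flat parameter vector
packed = fpzip.compress(data, precision=0, order="C")   # precision=0 -> lossless

raw = fpzip.decompress(packed, order="C")
# raw comes back with singleton axes padded in, e.g. (1, 1, 1, 1000)
restored = raw.squeeze()                                 # back to (1000,)
assert np.array_equal(data, restored)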
@@ -52,6 +52,14 @@ class Sharing:
         for n in self.my_neighbors:
             self.peer_deques[n] = deque()

+        self.shapes = []
+        self.lens = []
+        with torch.no_grad():
+            for _, v in self.model.state_dict().items():
+                self.shapes.append(v.shape)
+                t = v.flatten().numpy()
+                self.lens.append(t.shape[0])
+
     def received_from_all(self):
         """
         Check if all neighbors have sent the current iteration
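The two new lists cache, per state_dict entry, the tensor shape and its flattened length; deserialization later uses them to split the single flat array back into per-layer tensors. A small illustration of what they hold, using a hypothetical stand-in model that is not part of the commit:

import torch

model = torch.nn.Linear(3, 2)        # hypothetical stand-in for self.model
shapes, lens = [], []
with torch.no_grad():
    for _, v in model.state_dict().items():
        shapes.append(v.shape)                     # e.g. torch.Size([2, 3]), torch.Size([2])
        lens.append(v.flatten().numpy().shape[0])  # e.g. 6, 2
# shapes and lens now describe how to rebuild each tensor from a flat array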
@@ -95,11 +103,14 @@ class Sharing:
             Model converted to dict

         """
-        m = dict()
-        for key, val in self.model.state_dict().items():
-            m[key] = val.numpy()
+        to_cat = []
+        with torch.no_grad():
+            for _, v in self.model.state_dict().items():
+                t = v.flatten()
+                to_cat.append(t)
+        flat = torch.cat(to_cat)
         data = dict()
-        data["params"] = m
+        data["params"] = flat.numpy()
         return data

     def deserialized_model(self, m):
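serialized_model now concatenates every flattened state_dict tensor into one contiguous numpy array instead of a per-key dict, presumably so the compression layer sees a single float stream rather than many small per-layer arrays. A hedged sketch of just that flattening step, again with a stand-in model:

import torch

model = torch.nn.Linear(3, 2)        # hypothetical stand-in for self.model
with torch.no_grad():
    to_cat = [v.flatten() for v in model.state_dict().values()]
flat = torch.cat(to_cat).numpy()     # one 1-D array holding all parameters
data = {"params": flat}              # same payload layout as in the diff above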
@@ -118,8 +129,13 @@ class Sharing:
         """
         state_dict = dict()
-        for key, value in m["params"].items():
-            state_dict[key] = torch.from_numpy(value)
+        T = m["params"]
+        start_index = 0
+        for i, key in enumerate(self.model.state_dict()):
+            end_index = start_index + self.lens[i]
+            state_dict[key] = torch.from_numpy(T[start_index:end_index].reshape(self.shapes[i]))
+            start_index = end_index
         return state_dict

     def _pre_step(self):
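And the inverse: deserialized_model walks the state_dict keys in order, slices the flat array by the cached lengths, and reshapes each slice with the cached shape. A self-contained round-trip sketch (stand-in model, not the Sharing class itself):

import torch

model = torch.nn.Linear(3, 2)                        # hypothetical stand-in
shapes = [v.shape for v in model.state_dict().values()]
lens = [v.numel() for v in model.state_dict().values()]
flat = torch.cat([v.flatten() for v in model.state_dict().values()]).numpy()

state_dict, start = dict(), 0
for i, key in enumerate(model.state_dict()):
    end = start + lens[i]
    state_dict[key] = torch.from_numpy(flat[start:end].reshape(shapes[i]))
    start = end

model.load_state_dict(state_dict)                    # keys and shapes line up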