Commit 18c957e9 authored by Rishi Sharma

Partial model sharing FAST

parent 6ec0bd26
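Besides formatting cleanups (trailing commas, per-rank plot filenames), the substantive change below rewrites PartialModel: the old code flattened every accumulated gradient into Python (value, key, index) tuples and sorted the whole list, while the new code makes a single vectorized torch.topk call over the concatenated absolute gradients. A minimal before/after sketch of just the selection step, with hypothetical toy tensor shapes standing in for a real model's gradients:

import torch

# Hypothetical accumulated gradients, keyed like a state dict.
grads = {"weight": torch.randn(1000, 100), "bias": torch.randn(1000)}
alpha = 0.1  # fraction of entries to share

# Before: one Python tuple per scalar, then a full list sort (the
# "bottleneck" comment in the removed code).
gradient_sequence = [
    (val.item(), key, index)
    for key, gradient in grads.items()
    for index, val in enumerate(torch.flatten(gradient))
]
gradient_sequence.sort()

# After: concatenate once and let topk pick the alpha fraction of
# entries with the largest absolute value, entirely inside libtorch.
flat = torch.cat([g.flatten() for g in grads.values()], dim=0)
values, indices = torch.topk(
    flat.abs(), round(alpha * flat.shape[0]), dim=0, sorted=False
)

Besides avoiding interpreted-Python work, note the selection criterion also changes: the old tuple sort compared signed values in ascending order, while the new extract_top_gradients keeps the largest-magnitude entries via torch.abs.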
@@ -61,6 +61,6 @@ if __name__ == "__main__":
             args.iterations,
             args.log_dir,
             log_level[args.log_level],
-            args.test_after
+            args.test_after,
         ],
     )
@@ -198,21 +198,21 @@ class Node:
             "train_loss",
             "Training Loss",
             "Communication Rounds",
-            os.path.join(log_dir, "train_loss.png"),
+            os.path.join(log_dir, "{}_train_loss.png".format(self.rank)),
         )
         self.save_plot(
             self.test_loss,
             "test_loss",
             "Testing Loss",
             "Communication Rounds",
-            os.path.join(log_dir, "test_loss.png"),
+            os.path.join(log_dir, "{}_test_loss.png".format(self.rank)),
         )
         self.save_plot(
             self.test_acc,
             "test_acc",
             "Testing Accuracy",
             "Communication Rounds",
-            os.path.join(log_dir, "test_acc.png"),
+            os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
         )
         self.communication.disconnect_neighbors()
@@ -9,64 +9,94 @@ from decentralizepy.sharing.Sharing import Sharing

 class PartialModel(Sharing):
     def __init__(
-        self, rank, machine_id, communication, mapping, graph, model, dataset, alpha=1.0
+        self,
+        rank,
+        machine_id,
+        communication,
+        mapping,
+        graph,
+        model,
+        dataset,
+        alpha=1.0,
+        dict_ordered=True,
     ):
         super().__init__(
             rank, machine_id, communication, mapping, graph, model, dataset
         )
         self.alpha = alpha
+        self.dict_ordered = dict_ordered

-    def extract_sorted_gradients(self):
+    def extract_top_gradients(self):
         logging.info("Summing up gradients")
         assert len(self.model.accumulated_gradients) > 0
         gradient_sum = self.model.accumulated_gradients[0]
         for i in range(1, len(self.model.accumulated_gradients)):
             for key in self.model.accumulated_gradients[i]:
                 gradient_sum[key] += self.model.accumulated_gradients[i][key]
-        gradient_sequence = []
-
-        logging.info("Turning gradients into tuples")
-
-        for key, gradient in gradient_sum.items():
-            for index, val in enumerate(torch.flatten(gradient)):
-                gradient_sequence.append((val, key, index))
-
-        logging.info("Sorting gradient tuples")
-
-        gradient_sequence.sort()  # bottleneck
-        return gradient_sequence
+        logging.info("Returning topk gradients")
+        tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()]
+        G_topk = torch.abs(torch.cat(tensors_to_cat, dim=0))
+        return torch.topk(
+            G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=False
+        )

     def serialized_model(self):
-        gradient_sequence = self.extract_sorted_gradients()
-        logging.info("Extracted sorted gradients")
-
-        gradient_sequence = gradient_sequence[
-            : round(len(gradient_sequence) * self.alpha)
-        ]
-
-        m = dict()
-        for _, key, index in gradient_sequence:
-            if key not in m:
-                m[key] = []
-            m[key].append(
-                (
-                    index,
-                    torch.flatten(self.model.state_dict()[key])[index].numpy().tolist(),
-                )
-            )
-
-        logging.info("Generated dictionary to send")
-
-        for key in m:
-            m[key] = json.dumps(m[key])
-
-        logging.info("Converted dictionary to json")
-
-        return m
+        with torch.no_grad():
+            _, G_topk = self.extract_top_gradients()
+
+            logging.info("Extracting topk params")
+            tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
+            T = torch.cat(tensors_to_cat, dim=0)
+            T_topk = T[G_topk]
+
+            logging.info("Generating dictionary to send")
+            m = dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            m["indices"] = G_topk.numpy().tolist()
+            m["params"] = T_topk.numpy().tolist()
+
+            assert len(m["indices"]) == len(m["params"])
+            logging.info("Elements sending: {}".format(len(m["indices"])))
+
+            logging.info("Generated dictionary to send")
+
+            for key in m:
+                m[key] = json.dumps(m[key])
+
+            logging.info("Converted dictionary to json")
+
+            return m

     def deserialized_model(self, m):
-        state_dict = self.model.state_dict()
-
-        for key, value in m.items():
-            for index, param_val in json.loads(value):
-                torch.flatten(state_dict[key])[index] = param_val
-
-        return state_dict
+        with torch.no_grad():
+            state_dict = self.model.state_dict()
+
+            if not self.dict_ordered:
+                raise NotImplementedError
+
+            shapes = []
+            lens = []
+            tensors_to_cat = []
+            for _, v in state_dict.items():
+                shapes.append(v.shape)
+                t = v.flatten()
+                lens.append(t.shape[0])
+                tensors_to_cat.append(t)
+
+            T = torch.cat(tensors_to_cat, dim=0)
+            index_tensor = torch.tensor(json.loads(m["indices"]))
+            logging.debug("Original tensor: {}".format(T[index_tensor]))
+            T[index_tensor] = torch.tensor(json.loads(m["params"]))
+            logging.debug("Final tensor: {}".format(T[index_tensor]))
+
+            start_index = 0
+            for i, key in enumerate(state_dict):
+                end_index = start_index + lens[i]
+                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
+                start_index = end_index
+
+            return state_dict
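The new wire format is two JSON-encoded lists: "indices" (positions in the flattened parameter vector) and "params" (the values at those positions). A minimal end-to-end sketch of the serialize/deserialize round trip, using a hypothetical toy nn.Linear in place of the shared model:

import json

import torch
import torch.nn as nn

alpha = 0.5
model = nn.Linear(4, 2)  # hypothetical toy model

with torch.no_grad():
    # Sender side: flatten all parameters, keep the top-alpha fraction
    # by absolute value, and JSON-encode indices and values separately.
    T = torch.cat([v.data.flatten() for v in model.parameters()], dim=0)
    _, G_topk = torch.topk(T.abs(), round(alpha * T.shape[0]), sorted=False)
    m = {
        "indices": json.dumps(G_topk.numpy().tolist()),
        "params": json.dumps(T[G_topk].numpy().tolist()),
    }

    # Receiver side: write the received values into a flattened copy of
    # the state dict, then cut it back into per-parameter tensors.
    state_dict = model.state_dict()
    shapes = [v.shape for v in state_dict.values()]
    lens = [v.flatten().shape[0] for v in state_dict.values()]
    T2 = torch.cat([v.flatten() for v in state_dict.values()], dim=0)
    T2[torch.tensor(json.loads(m["indices"]))] = torch.tensor(
        json.loads(m["params"])
    )

    start_index = 0
    for i, key in enumerate(state_dict):
        end_index = start_index + lens[i]
        state_dict[key] = T2[start_index:end_index].reshape(shapes[i])
        start_index = end_index
    # state_dict now holds the merged parameters, ready for
    # model.load_state_dict(state_dict).

Because both sides flatten the state dict in the same key order, a flat global index is enough to address any parameter; that ordering assumption is what the dict_ordered flag and its NotImplementedError guard make explicit.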
@@ -28,7 +28,7 @@ def get_args():
     parser.add_argument("-ll", "--log_level", type=str, default="INFO")
     parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
     parser.add_argument("-gt", "--graph_type", type=str, default="edges")
-    parser.add_argument("-ta", "--test_after", type=int, default = 5)
+    parser.add_argument("-ta", "--test_after", type=int, default=5)
     args = parser.parse_args()
     return args
@@ -45,7 +45,7 @@ def write_args(args, path):
         "log_level": args.log_level,
         "graph_file": args.graph_file,
         "graph_type": args.graph_type,
-        "test_after": args.test_after
+        "test_after": args.test_after,
     }
     with open(os.path.join(path, "args.json"), "w") as of:
         json.dump(data, of)