Skip to content
Snippets Groups Projects
Commit 7575bd98 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Separate files shared_params

parent 3c3007a0
No related branches found
No related tags found
No related merge requests found
{
"0": "10.90.41.130",
"1": "10.90.41.131",
"2": "10.90.41.132",
"3": "10.90.41.133"
}
\ No newline at end of file
{
"0": "10.90.41.129",
"1": "10.90.41.130",
"2": "10.90.41.131",
"3": "10.90.41.132",
"4": "10.90.41.133"
}
\ No newline at end of file
import json import json
import logging import logging
import os import os
from pathlib import Path
import numpy import numpy
import torch import torch
...@@ -28,6 +29,10 @@ class PartialModel(Sharing): ...@@ -28,6 +29,10 @@ class PartialModel(Sharing):
self.alpha = alpha self.alpha = alpha
self.dict_ordered = dict_ordered self.dict_ordered = dict_ordered
self.communication_round = 0 self.communication_round = 0
self.folder_path = os.path.join(
self.log_dir, "shared_params/{}".format(self.rank)
)
Path(self.folder_path).mkdir(parents=True, exist_ok=True)
def extract_top_gradients(self): def extract_top_gradients(self):
logging.info("Summing up gradients") logging.info("Summing up gradients")
...@@ -48,26 +53,20 @@ class PartialModel(Sharing): ...@@ -48,26 +53,20 @@ class PartialModel(Sharing):
with torch.no_grad(): with torch.no_grad():
_, G_topk = self.extract_top_gradients() _, G_topk = self.extract_top_gradients()
if self.communication_round: shared_params = dict()
with open( shared_params["order"] = list(self.model.state_dict().keys())
os.path.join( shapes = dict()
self.log_dir, "{}_shared_params.json".format(self.rank) for k, v in self.model.state_dict().items():
), shapes[k] = list(v.shape)
"r", shared_params["shapes"] = shapes
) as inf:
shared_params = json.load(inf)
else:
shared_params = dict()
shared_params["order"] = list(self.model.state_dict().keys())
shapes = dict()
for k, v in self.model.state_dict().items():
shapes[k] = list(v.shape)
shared_params["shapes"] = shapes
shared_params[self.communication_round] = G_topk.tolist() shared_params[self.communication_round] = G_topk.tolist()
with open( with open(
os.path.join(self.log_dir, "{}_shared_params.json".format(self.rank)), os.path.join(
self.folder_path,
"{}_shared_params.json".format(self.communication_round + 1),
),
"w", "w",
) as of: ) as of:
json.dump(shared_params, of) json.dump(shared_params, of)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment