Commit dd797778 authored by Rishi Sharma

Fix writes of acc, loss; Log which params are shared

parent 5985bb90
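The first hunk fixes a bug in the metrics logging: the "{}" placeholder in the JSON file names was never filled in, so every node wrote to the same literal file "{}_train_loss.json" and the ranks overwrote each other's results. A minimal sketch of the before/after behavior:

    # Before: .format(self.rank) is missing, so all ranks open the same
    # literal file and the last writer wins.
    with open(os.path.join(log_dir, "{}_train_loss.json"), "w") as of:
        json.dump(self.train_loss, of)

    # After: the rank is interpolated, giving one file per node,
    # e.g. "3_train_loss.json" for the node with rank 3.
    with open(
        os.path.join(log_dir, "{}_train_loss.json".format(self.rank)), "w"
    ) as of:
        json.dump(self.train_loss, of)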
@@ -219,11 +219,17 @@ class Node:
             os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
         )
-        with open(os.path.join(log_dir, "{}_train_loss.json"), "w") as of:
+        with open(
+            os.path.join(log_dir, "{}_train_loss.json".format(self.rank)), "w"
+        ) as of:
             json.dump(self.train_loss, of)
-        with open(os.path.join(log_dir, "{}_test_loss.json"), "w") as of:
+        with open(
+            os.path.join(log_dir, "{}_test_loss.json".format(self.rank)), "w"
+        ) as of:
             json.dump(self.test_loss, of)
-        with open(os.path.join(log_dir, "{}_test_acc.json"), "w") as of:
+        with open(
+            os.path.join(log_dir, "{}_test_acc.json".format(self.rank)), "w"
+        ) as of:
             json.dump(self.test_acc, of)
         self.communication.disconnect_neighbors()
+import json
 import logging
 import os
 import numpy
 import torch
@@ -17,14 +18,16 @@ class PartialModel(Sharing):
         graph,
         model,
         dataset,
+        log_dir,
         alpha=1.0,
         dict_ordered=True,
     ):
         super().__init__(
-            rank, machine_id, communication, mapping, graph, model, dataset
+            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
         )
         self.alpha = alpha
         self.dict_ordered = dict_ordered
+        self.communication_round = 0

     def extract_top_gradients(self):
         logging.info("Summing up gradients")
@@ -44,6 +47,31 @@
     def serialized_model(self):
         with torch.no_grad():
             _, G_topk = self.extract_top_gradients()
+
+            if self.communication_round:
+                with open(
+                    os.path.join(
+                        self.log_dir, "{}_shared_params.json".format(self.rank)
+                    ),
+                    "r",
+                ) as inf:
+                    shared_params = json.load(inf)
+            else:
+                shared_params = dict()
+                # Recorded once, on the first round; later rounds reload the
+                # file from disk and only append new indices.
+                shared_params["order"] = list(self.model.state_dict().keys())
+                shapes = dict()
+                for k, v in self.model.state_dict().items():
+                    shapes[k] = list(v.shape)
+                shared_params["shapes"] = shapes
+
+            shared_params[self.communication_round] = G_topk.tolist()
+
+            with open(
+                os.path.join(self.log_dir, "{}_shared_params.json".format(self.rank)),
+                "w",
+            ) as of:
+                json.dump(shared_params, of)
+
             logging.info("Extracting topk params")
             tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
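The block above is the new "log which params are shared" behavior: on round 0 the node records the state_dict key order and per-tensor shapes, and on every round it appends the indices returned by extract_top_gradients, so the JSON file accumulates one entry per communication round. Assuming G_topk holds flat indices into the concatenated gradient vector, a hypothetical "0_shared_params.json" for rank 0 after two rounds would look roughly like:

    {
        "order": ["fc.weight", "fc.bias"],
        "shapes": {"fc.weight": [4, 2], "fc.bias": [4]},
        "0": [1, 5, 7],
        "1": [0, 5, 9]
    }

Note that json.dump stringifies the integer round keys, so after a reload the dict mixes string keys (from the file) with the int key being added; they never collide because the round number is new each time.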
@@ -11,7 +11,9 @@ class Sharing:
     API defining who to share with and what, and what to do on receiving
     """

-    def __init__(self, rank, machine_id, communication, mapping, graph, model, dataset):
+    def __init__(
+        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
+    ):
         self.rank = rank
         self.machine_id = machine_id
         self.uid = mapping.get_uid(rank, machine_id)
@@ -20,6 +22,7 @@ class Sharing:
         self.graph = graph
         self.model = model
         self.dataset = dataset
+        self.log_dir = log_dir
         self.peer_deques = dict()

         my_neighbors = self.graph.neighbors(self.uid)
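With log_dir now threaded through the Sharing base class, any sharing strategy can write diagnostics next to the node's other logs. A hypothetical wiring of the updated constructor (comm, mapping, graph, model, and dataset stand in for the caller's real objects, which this diff does not show):

    sharing = PartialModel(
        rank=0,
        machine_id=0,
        communication=comm,
        mapping=mapping,
        graph=graph,
        model=model,
        dataset=dataset,
        log_dir="./logs",  # new argument: where {rank}_shared_params.json goes
        alpha=0.1,         # illustrative value
    )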