import importlib
import json
import logging
import math
import os
from time import time

import numpy as np
import torch
from matplotlib import pyplot as plt

from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Mapping import Mapping
class Node:
    """
    This class defines the node (entity that performs learning, sharing and communication).

    """
    def save_plot(self, l, label, title, xlabel, filename):
        """
        Save Matplotlib plot. Clears previous plots.

        Parameters
        ----------
        l : dict
            dict of x -> y. `x` must be castable to int.
        label : str
            label of the plot. Used for legend.
        title : str
            Header
        xlabel : str
            x-axis label
        filename : str
            Name of file to save the plot as.

        """
        plt.clf()
        y_axis = [l[key] for key in l.keys()]
        x_axis = list(map(int, l.keys()))
        plt.plot(x_axis, y_axis, label=label)
        plt.xlabel(xlabel)
        plt.title(title)
        plt.savefig(filename)
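    # Illustrative use only: each results dict built in run() maps a round number
    # to a metric value, so a call looks roughly like
    #     self.save_plot({1: 0.93, 2: 0.71}, "train_loss", "Training Loss",
    #                    "Communication Rounds", "0_train_loss.png")
    # where the dict keys become the x-axis (cast to int) and the values the y-axis.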
    def init_log(self, log_dir, rank, log_level):
        """
        Instantiate logging; writes this process's log to `<rank>.log` in log_dir.

        Parameters
        ----------
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        """
        log_file = os.path.join(log_dir, str(rank) + ".log")
        logging.basicConfig(
            filename=log_file,
            format="[%(asctime)s][%(module)s][%(levelname)s] %(message)s",
            level=log_level,
            force=True,
        )
    def cache_fields(
        self, rank, machine_id, mapping, graph, iterations, log_dir, test_after
    ):
        """
        Instantiate object fields with arguments.

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process is running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        iterations : int
            Number of communication rounds to train for
        log_dir : str
            Logging directory
        test_after : int
            Evaluate the model every `test_after` rounds

        """
        self.rank = rank
        self.machine_id = machine_id
        self.graph = graph
        self.mapping = mapping
        self.uid = self.mapping.get_uid(rank, machine_id)
        self.log_dir = log_dir
        self.iterations = iterations
        self.test_after = test_after

        logging.debug("Rank: %d", self.rank)
        logging.debug("type(graph): %s", str(type(self.graph)))
        logging.debug("type(mapping): %s", str(type(self.mapping)))
    def init_dataset_model(self, dataset_configs):
        """
        Instantiate dataset and model from config.

        Parameters
        ----------
        dataset_configs : dict
            Python dict containing dataset config params

        """
        dataset_module = importlib.import_module(dataset_configs["dataset_package"])
        self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
        self.dataset_params = utils.remove_keys(
            dataset_configs, ["dataset_package", "dataset_class", "model_class"]
        )
        self.dataset = self.dataset_class(
            self.rank, self.machine_id, self.mapping, **self.dataset_params
        )
        logging.info("Dataset instantiation complete.")
        self.model_class = getattr(dataset_module, dataset_configs["model_class"])
        self.model = self.model_class()
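    # Illustrative config only (package/class names are placeholders, not fixed by
    # this method): a [DATASET] section such as
    #     dataset_package = decentralizepy.datasets.Femnist
    #     dataset_class = Femnist
    #     model_class = CNN
    #     train_dir = /path/to/train
    # resolves the dataset and model classes via getattr on the imported package;
    # every remaining key (here train_dir) is forwarded to the dataset constructor.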
    def init_optimizer(self, optimizer_configs):
        """
        Instantiate optimizer from config.

        Parameters
        ----------
        optimizer_configs : dict
            Python dict containing optimizer config params

        """
        optimizer_module = importlib.import_module(
            optimizer_configs["optimizer_package"]
        )
        self.optimizer_class = getattr(
            optimizer_module, optimizer_configs["optimizer_class"]
        )
        self.optimizer_params = utils.remove_keys(
            optimizer_configs, ["optimizer_package", "optimizer_class"]
        )
        self.optimizer = self.optimizer_class(
            self.model.parameters(), **self.optimizer_params
        )
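    # Illustrative config only: an [OPTIMIZER_PARAMS] section such as
    #     optimizer_package = torch.optim
    #     optimizer_class = SGD
    #     lr = 0.001
    # makes self.optimizer_class == torch.optim.SGD, and the leftover keys are
    # forwarded, i.e. the optimizer becomes torch.optim.SGD(model.parameters(), lr=0.001).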
    def init_trainer(self, train_configs):
        """
        Instantiate training module and loss from config.

        Parameters
        ----------
        train_configs : dict
            Python dict containing training config params

        """
        train_module = importlib.import_module(train_configs["training_package"])
        train_class = getattr(train_module, train_configs["training_class"])

        loss_package = importlib.import_module(train_configs["loss_package"])
        if "loss_class" in train_configs.keys():
            loss_class = getattr(loss_package, train_configs["loss_class"])
            self.loss = loss_class()
        else:
            self.loss = getattr(loss_package, train_configs["loss"])

        train_params = utils.remove_keys(
            train_configs,
            [
                "training_package",
                "training_class",
                "loss",
                "loss_package",
                "loss_class",
            ],
        )
        self.trainer = train_class(
            self.rank,
            self.machine_id,
            self.mapping,
            self.model,
            self.optimizer,
            self.loss,
            self.log_dir,
            **train_params
        )
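    # Two loss styles are supported: `loss_class` names a class that gets
    # instantiated (e.g. loss_package = torch.nn, loss_class = CrossEntropyLoss,
    # so self.loss = torch.nn.CrossEntropyLoss()), while `loss` names an existing
    # callable (e.g. loss_package = torch.nn.functional, loss = cross_entropy).
    # Both pairs are illustrative examples, not the only valid choices.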
    def init_comm(self, comm_configs):
        """
        Instantiate communication module from config.

        Parameters
        ----------
        comm_configs : dict
            Python dict containing communication config params

        """
        comm_module = importlib.import_module(comm_configs["comm_package"])
        comm_class = getattr(comm_module, comm_configs["comm_class"])
        comm_params = utils.remove_keys(comm_configs, ["comm_package", "comm_class"])
        self.communication = comm_class(
            self.rank, self.machine_id, self.mapping, self.graph.n_procs, **comm_params
        )
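    # Illustrative config only: the [COMMUNICATION] section names comm_package and
    # comm_class (for example a TCP-based implementation shipped with
    # decentralizepy), and any remaining keys (addresses file, ports, ...) are
    # passed to that class as keyword arguments alongside rank, machine_id,
    # the mapping and the total number of processes.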
    def init_sharing(self, sharing_configs):
        """
        Instantiate sharing module from config.

        Parameters
        ----------
        sharing_configs : dict
            Python dict containing sharing config params

        """
        sharing_package = importlib.import_module(sharing_configs["sharing_package"])
        sharing_class = getattr(sharing_package, sharing_configs["sharing_class"])
        sharing_params = utils.remove_keys(
            sharing_configs, ["sharing_package", "sharing_class"]
        )
        self.sharing = sharing_class(
            self.rank,
            self.machine_id,
            self.communication,
            self.mapping,
            self.graph,
            self.model,
            self.dataset,
            self.log_dir,
            **sharing_params
        )
    def instantiate(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        log_level=logging.INFO,
        test_after=5,
        *args
    ):
        """
        Construct objects.

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process is running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations.
        iterations : int
            Number of communication rounds to train for
        log_dir : str
            Logging directory
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Evaluate the model every `test_after` rounds
        args : optional
            Other arguments

        """
        logging.info("Started process.")

        self.cache_fields(
            rank, machine_id, mapping, graph, iterations, log_dir, test_after
        )
        self.init_log(log_dir, rank, log_level)
        self.init_dataset_model(config["DATASET"])
        self.init_optimizer(config["OPTIMIZER_PARAMS"])
        self.init_trainer(config["TRAIN_PARAMS"])
        self.init_comm(config["COMMUNICATION"])
        self.init_sharing(config["SHARING"])
    def run(self):
        """
        Start the decentralized learning.

        """
        self.communication.connect_neighbors(self.graph.neighbors(self.uid))

        for iteration in range(self.iterations):
            logging.info("Starting training iteration: %d", iteration)
            self.trainer.train(self.dataset)

            self.sharing.step()

            # Reset optimizer state so nothing leaks across communication rounds.
            self.optimizer = self.optimizer_class(
                self.model.parameters(), **self.optimizer_params
            )
            self.trainer.reset_optimizer(self.optimizer)

            if iteration:
                with open(
                    os.path.join(self.log_dir, "{}_results.json".format(self.rank)),
                    "r",
                ) as inf:
                    results_dict = json.load(inf)
            else:
                results_dict = {
                    "train_loss": {},
                    "test_loss": {},
                    "test_acc": {},
                    "total_bytes": {},
                    "total_meta": {},
                    "total_data_per_n": {},
                    "grad_mean": {},
                    "grad_std": {},
                }

            results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
            results_dict["total_meta"][iteration + 1] = self.sharing.total_meta
            results_dict["total_data_per_n"][
                iteration + 1
            ] = self.sharing.total_data
            results_dict["grad_mean"][iteration + 1] = self.sharing.mean
            results_dict["grad_std"][iteration + 1] = self.sharing.std

            # Evaluate only every `test_after` rounds to keep evaluation cost down.
            if (iteration + 1) % self.test_after == 0:
                logging.info("Evaluating on train set.")
                # Uses its own trainset, iterates over it in its entirety
                loss_after_sharing = self.trainer.eval_loss(self.dataset)
                results_dict["train_loss"][iteration + 1] = loss_after_sharing
                self.save_plot(
                    results_dict["train_loss"],
                    "train_loss",
                    "Training Loss",
                    "Communication Rounds",
                    os.path.join(self.log_dir, "{}_train_loss.png".format(self.rank)),
                )

                logging.info("Evaluating on test set.")
                ta, tl = self.dataset.test(self.model, self.loss)
                results_dict["test_acc"][iteration + 1] = ta
                results_dict["test_loss"][iteration + 1] = tl
                self.save_plot(
                    results_dict["test_loss"],
                    "test_loss",
                    "Testing Loss",
                    "Communication Rounds",
                    os.path.join(self.log_dir, "{}_test_loss.png".format(self.rank)),
                )
                self.save_plot(
                    results_dict["test_acc"],
                    "test_acc",
                    "Testing Accuracy",
                    "Communication Rounds",
                    os.path.join(self.log_dir, "{}_test_acc.png".format(self.rank)),
                )

            with open(
                os.path.join(self.log_dir, "{}_results.json".format(self.rank)), "w"
            ) as of:
                json.dump(results_dict, of)

        self.communication.disconnect_neighbors()
        logging.info("All neighbors disconnected. Process complete!")
    def __init__(
        self,
        rank: int,
        machine_id: int,
        mapping: Mapping,
        graph: Graph,
        config,
        iterations=1,
        log_dir=".",
        log_level=logging.INFO,
        test_after=5,
        *args
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Rank of process local to the machine
        machine_id : int
            Machine ID on which the process is running
        mapping : decentralizepy.mappings
            The object containing the mapping rank <--> uid
        graph : decentralizepy.graphs
            The object containing the global graph
        config : dict
            A dictionary of configurations. Must contain the following:
            [DATASET]
                dataset_package
                dataset_class
                model_class
            [OPTIMIZER_PARAMS]
                optimizer_package
                optimizer_class
            [TRAIN_PARAMS]
                training_package = decentralizepy.training.Training
                training_class = Training
                epochs_per_round = 25
                batch_size = 64
            [COMMUNICATION]
                comm_package
                comm_class
            [SHARING]
                sharing_package
                sharing_class
        iterations : int
            Number of communication rounds to train for. Default value is 1.
        log_dir : str
            Logging directory
        log_level : logging.Level
            One of DEBUG, INFO, WARNING, ERROR, CRITICAL
        test_after : int
            Evaluate the model every `test_after` rounds. Default value is 5.
        args : optional
            Other arguments

        """
        total_threads = os.cpu_count()
        threads_per_proc = max(math.floor(total_threads / mapping.procs_per_machine), 1)
        torch.set_num_threads(threads_per_proc)
        torch.set_num_interop_threads(1)
        self.instantiate(
            rank,
            machine_id,
            mapping,
            graph,
            config,
            iterations,
            log_dir,
            log_level,
            test_after,
            *args
        )
        logging.info(
            "Each proc uses %d threads out of %d.", threads_per_proc, total_threads
        )
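

# A minimal launch sketch, not part of the Node class itself. It assumes that a
# Linear mapping and a FullyConnected graph are available under
# decentralizepy.mappings and decentralizepy.graphs (swap in whatever mapping and
# topology your deployment actually uses) and that `config` is filled with the
# sections described in the constructor docstring (DATASET, OPTIMIZER_PARAMS,
# TRAIN_PARAMS, COMMUNICATION, SHARING).
if __name__ == "__main__":
    from decentralizepy.graphs.FullyConnected import FullyConnected
    from decentralizepy.mappings.Linear import Linear

    n_machines, procs_per_machine = 1, 4
    mapping = Linear(n_machines, procs_per_machine)  # rank <--> uid mapping
    graph = FullyConnected(n_machines * procs_per_machine)  # every node is a neighbor

    config = {}  # section name -> dict of params, as documented above

    node = Node(
        rank=0,
        machine_id=0,
        mapping=mapping,
        graph=graph,
        config=config,
        iterations=10,
        log_dir=".",
        log_level=logging.INFO,
        test_after=5,
    )
    node.run()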