Skip to content
Snippets Groups Projects
Commit 5985bb90 authored by Rishi Sharma
Browse files

Partial Model multi-machine

parent 1effbc35
No related branches found
No related tags found
No related merge requests found
File moved
......@@ -11,12 +11,12 @@ sizes =
[OPTIMIZER_PARAMS]
optimizer_package = torch.optim
optimizer_class = Adam
lr = 0.01
lr = 0.001
[TRAIN_PARAMS]
training_package = decentralizepy.training.GradientAccumulator
training_class = GradientAccumulator
epochs_per_round = 3
epochs_per_round = 5
batch_size = 1024
shuffle = True
loss_package = torch.nn
......@@ -30,4 +30,3 @@ addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.PartialModel
sharing_class = PartialModel
alpha = 1.0
\ No newline at end of file
......@@ -54,8 +54,9 @@ class TCP(Communication):
return sender, data
def connect_neighbors(self, neighbors):
logging.info("Sending connection request to neighbors")
for uid in neighbors:
logging.info("Connecting to my neighbour: {}".format(uid))
logging.debug("Connecting to my neighbour: {}".format(uid))
id = str(uid).encode()
req = self.context.socket(zmq.DEALER)
req.setsockopt(zmq.IDENTITY, self.identity)
......@@ -68,10 +69,10 @@ class TCP(Communication):
sender, recv = self.router.recv_multipart()
if recv == HELLO:
logging.info("Received {} from {}".format(HELLO, sender))
logging.debug("Received {} from {}".format(HELLO, sender))
self.barrier.add(sender)
elif recv == BYE:
logging.info("Received {} from {}".format(BYE, sender))
logging.debug("Received {} from {}".format(BYE, sender))
raise RuntimeError(
"A neighbour wants to disconnect before training started!"
)
......@@ -82,6 +83,8 @@ class TCP(Communication):
self.peer_deque.append(self.decrypt(sender, recv))
logging.info("Connected to all neighbors")
def receive(self):
if len(self.peer_deque) != 0:
resp = self.peer_deque[0]
......@@ -91,12 +94,12 @@ class TCP(Communication):
sender, recv = self.router.recv_multipart()
if recv == HELLO:
logging.info("Received {} from {}".format(HELLO, sender))
logging.debug("Received {} from {}".format(HELLO, sender))
raise RuntimeError(
"A neighbour wants to connect when everyone is connected!"
)
elif recv == BYE:
logging.info("Received {} from {}".format(BYE, sender))
logging.debug("Received {} from {}".format(BYE, sender))
self.barrier.remove(sender)
return self.receive()
else:
......@@ -107,17 +110,18 @@ class TCP(Communication):
to_send = self.encrypt(data)
id = str(uid).encode()
self.peer_sockets[id].send(to_send)
logging.info("{} sent the message to {}.".format(self.uid, uid))
logging.debug("{} sent the message to {}.".format(self.uid, uid))
def disconnect_neighbors(self):
if not self.sent_disconnections:
logging.info("Disconnecting neighbors")
for sock in self.peer_sockets.values():
sock.send(BYE)
self.sent_disconnections = True
while len(self.barrier):
sender, recv = self.router.recv_multipart()
if recv == BYE:
logging.info("Received {} from {}".format(BYE, sender))
logging.debug("Received {} from {}".format(BYE, sender))
self.barrier.remove(sender)
else:
logging.critical(
......
......@@ -252,7 +252,6 @@ class Femnist(Dataset):
plt.show()
def test(self, model, loss):
logging.debug("Evaluating on test set.")
testloader = self.get_testset()
logging.debug("Test Loader instantiated.")
......@@ -279,7 +278,7 @@ class Femnist(Dataset):
total_pred[label] += 1
total_predicted += 1
logging.info("Predicted on the test set")
logging.debug("Predicted on the test set")
for key, value in enumerate(correct_pred):
if total_pred[key] != 0:
......@@ -291,7 +290,6 @@ class Femnist(Dataset):
accuracy = 100 * float(total_correct) / total_predicted
loss_val = loss_val / count
logging.info("Overall accuracy is: {:.1f} %".format(accuracy))
logging.info("Evaluating complete.")
return accuracy, loss_val
......
import importlib
import json
import logging
import os
......@@ -17,7 +18,9 @@ class Node:
def save_plot(self, l, label, title, xlabel, filename):
plt.clf()
plt.plot(l, label=label)
x_axis = l.keys()
y_axis = [l[key] for key in x_axis]
plt.plot(x_axis, y_axis, label=label)
plt.xlabel(xlabel)
plt.title(title)
plt.savefig(filename)
......@@ -168,9 +171,9 @@ class Node:
self.testset = self.dataset.get_testset()
rounds_to_test = test_after
self.train_loss = []
self.test_loss = []
self.test_acc = []
self.train_loss = dict()
self.test_loss = dict()
self.test_acc = dict()
for iteration in range(iterations):
logging.info("Starting training iteration: %d", iteration)
......@@ -183,36 +186,44 @@ class Node:
self.trainer.reset_optimizer(self.optimizer)
loss_after_sharing = self.trainer.eval_loss(self.dataset)
self.train_loss.append(loss_after_sharing)
self.train_loss[iteration + 1] = loss_after_sharing
rounds_to_test -= 1
if self.dataset.__testing__ and rounds_to_test == 0:
logging.info("Evaluating on test set.")
rounds_to_test = test_after
ta, tl = self.dataset.test(self.model, self.loss)
self.test_acc.append(ta)
self.test_loss.append(tl)
self.save_plot(
self.train_loss,
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(log_dir, "{}_train_loss.png".format(self.rank)),
)
self.save_plot(
self.test_loss,
"test_loss",
"Testing Loss",
"Communication Rounds",
os.path.join(log_dir, "{}_test_loss.png".format(self.rank)),
)
self.save_plot(
self.test_acc,
"test_acc",
"Testing Accuracy",
"Communication Rounds",
os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
)
self.test_acc[iteration + 1] = ta
self.test_loss[iteration + 1] = tl
self.save_plot(
self.train_loss,
"train_loss",
"Training Loss",
"Communication Rounds",
os.path.join(log_dir, "{}_train_loss.png".format(self.rank)),
)
self.save_plot(
self.test_loss,
"test_loss",
"Testing Loss",
"Communication Rounds",
os.path.join(log_dir, "{}_test_loss.png".format(self.rank)),
)
self.save_plot(
self.test_acc,
"test_acc",
"Testing Accuracy",
"Communication Rounds",
os.path.join(log_dir, "{}_test_acc.png".format(self.rank)),
)
# Persist the collected metrics as JSON next to the plots. The rank is
# formatted into each filename (as is done for the PNG plots above) so the
# output is "<rank>_train_loss.json" rather than a file literally named
# "{}_train_loss.json".
with open(
    os.path.join(log_dir, "{}_train_loss.json".format(self.rank)), "w"
) as of:
    json.dump(self.train_loss, of)
with open(
    os.path.join(log_dir, "{}_test_loss.json".format(self.rank)), "w"
) as of:
    json.dump(self.test_loss, of)
with open(
    os.path.join(log_dir, "{}_test_acc.json".format(self.rank)), "w"
) as of:
    json.dump(self.test_acc, of)
self.communication.disconnect_neighbors()
......@@ -57,13 +57,14 @@ class Sharing:
for neighbor in iter_neighbors:
self.communication.send(neighbor, data)
logging.info("Waiting for messages from neighbors")
while not self.received_from_all():
sender, data = self.communication.receive()
logging.info("Received model from {}".format(sender))
logging.debug("Received model from {}".format(sender))
degree = data["degree"]
del data["degree"]
self.peer_deques[sender].append((degree, self.deserialized_model(data)))
logging.info("Deserialized received model from {}".format(sender))
logging.debug("Deserialized received model from {}".format(sender))
logging.info("Starting model averaging after receiving from all neighbors")
total = dict()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment