Commit b2fe436a authored by Jeffrey Wigger

sort of working now

parent c5e06f26
......@@ -24,11 +24,11 @@ loss_package = torch.nn
loss_class = CrossEntropyLoss
[COMMUNICATION]
comm_package = decentralizepy.communication.TCPRandomWalkRouting
comm_class = TCPRandomWalkRouting
comm_package = decentralizepy.communication.TCPRandomWalk
comm_class = TCPRandomWalk
addresses_filepath = ip_addr_6Machines.json
sampler = equi
[SHARING]
sharing_package = decentralizepy.sharing.SharingWithRWAsyncDynamic
sharing_class = SharingWithRWAsyncDynamic
\ No newline at end of file
sharing_package = decentralizepy.sharing.DPSGDRW
sharing_class = DPSGDRW
\ No newline at end of file
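For context, the change above swaps the communication backend from TCPRandomWalkRouting to TCPRandomWalk and the sharing class to DPSGDRW. A minimal sketch, not part of this commit, of how such a [COMMUNICATION] section is resolved at start-up, assuming the usual configparser-plus-importlib pattern that this commit also uses for the optional compressor; the mapping, rank and process counts below are placeholders:

import configparser
import importlib

from decentralizepy.mappings.Linear import Linear  # assumed mapping class

config = configparser.ConfigParser()
config.read("config.ini")                # placeholder path
comm_cfg = config["COMMUNICATION"]

comm_module = importlib.import_module(comm_cfg["comm_package"])
comm_class = getattr(comm_module, comm_cfg["comm_class"])

mapping = Linear(6, 16)                  # 6 machines x 16 procs, illustrative
comm = comm_class(
    rank=0,
    machine_id=0,
    mapping=mapping,
    total_procs=96,
    addresses_filepath=comm_cfg["addresses_filepath"],
    sampler=comm_cfg.get("sampler", "equi"),
)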
import faulthandler
import importlib
import json
import logging
import lzma
......@@ -6,13 +6,15 @@ import pickle
import time
import weakref
from collections import deque
from ctypes import c_int
from ctypes import c_int, c_long
from multiprocessing.sharedctypes import Value
from queue import Empty
import traceback
from multiprocessing import Lock
import torch
import zmq
import faulthandler
faulthandler.enable()
import torch.multiprocessing as mp
......@@ -49,8 +51,65 @@ class TCPRandomWalkBase(Communication):
"""
machine_addr = self.ip_addrs[str(machine_id)]
port = rank + 45000
port = rank + self.offset
return "tcp://{}:{}".format(machine_addr, port)
def __init__(
self,
rank,
machine_id,
mapping,
total_procs,
addresses_filepath,
compress=False,
offset=2000,
compression_package=None,
compression_class=None,
):
"""
Constructor
Parameters
----------
rank : int
Local rank of the process
machine_id : int
Machine id of the process
mapping : decentralizepy.mappings.Mapping
uid, rank, machine_id invertible mapping
total_procs : int
Total number of processes
addresses_filepath : str
JSON file with machine_id -> ip mapping
compression_package : str
Import path of a module that implements the compression.Compression.Compression class
compression_class : str
Name of the compression class inside the compression package
"""
super().__init__(rank, machine_id, mapping, total_procs)
self.addresses_filepath = addresses_filepath
self.compress = compress
self.offset = 10000 + offset
self.compression_package = compression_package
self.compression_class = compression_class
with open(addresses_filepath) as addrs:
self.ip_addrs = json.load(addrs)
self.identity = str(self.uid).encode()
if compression_package and compression_class:
compressor_module = importlib.import_module(compression_package)
compressor_class = getattr(compressor_module, compression_class)
self.compressor = compressor_class()
logging.info(f"Using the {compressor_class} to compress the data")
else:
assert not self.compress
self.total_data = 0 # dummy values
self.total_meta = 0
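Worked example of the new addressing: with the default offset=2000, self.offset becomes 10000 + 2000 = 12000, so a process with rank 3 now binds port 12003 instead of the previously hard-coded 45003.

# illustrative only; the real IP comes from ip_addr_6Machines.json
offset = 10000 + 2000                                # self.offset with the default offset=2000
port = 3 + offset                                    # rank 3 -> 12003
addr = "tcp://{}:{}".format("192.168.0.1", port)     # placeholder IP -> "tcp://192.168.0.1:12003"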
def encrypt(self, data):
"""
......@@ -67,12 +126,29 @@ class TCPRandomWalkBase(Communication):
Encoded data
"""
logging.debug("in encrypt")
if self.compress:
compressor = lzma.LZMACompressor()
output = compressor.compress(pickle.dumps(data)) + compressor.flush()
else:
logging.debug("in encrypt: compress")
if "indices" in data:
data["indices"] = self.compressor.compress(data["indices"])
assert "params" in data
data["params"] = self.compressor.compress_float(data["params"])
data_len = len(pickle.dumps(data["params"]))
output = pickle.dumps(data)
# the compressed metadata only gets a few bytes smaller after pickling
self.total_meta.value += (len(output) - data_len)
self.total_data.value += data_len
else:
logging.debug("in encrypt: else")
output = pickle.dumps(data)
# centralized testing uses its own instance
if type(data) == dict:
assert "params" in data
data_len = len(pickle.dumps(data["params"]))
self.total_meta.value += (len(output) - data_len)
self.total_data.value += data_len
return output
def decrypt(self, sender, data):
......@@ -92,10 +168,17 @@ class TCPRandomWalkBase(Communication):
(sender: int, data: dict)
"""
logging.debug("in decrypt")
sender = int(sender.decode())
if self.compress:
data = pickle.loads(lzma.decompress(data))
logging.debug("in decrypt:comp")
data = pickle.loads(data)
if "indices" in data:
data["indices"] = self.compressor.decompress(data["indices"])
if "params" in data:
data["params"] = self.compressor.decompress_float(data["params"])
else:
logging.debug("in decrypt:else")
data = pickle.loads(data)
return sender, data
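A hedged sketch, not from the repository, of the new compressed encrypt/decrypt round trip. IdentityCompressor is a stand-in for whatever compression_package/compression_class provides; it only has to expose the four methods called above.

import pickle

import numpy as np

class IdentityCompressor:
    def compress(self, indices):
        return indices

    def decompress(self, indices):
        return indices

    def compress_float(self, params):
        return params

    def decompress_float(self, params):
        return params

compressor = IdentityCompressor()

# sender side (mirrors encrypt() with compression enabled)
data = {"indices": np.array([0, 2, 5]), "params": np.array([0.1, 0.2, 0.3], dtype=np.float32)}
data["indices"] = compressor.compress(data["indices"])
data["params"] = compressor.compress_float(data["params"])
wire = pickle.dumps(data)

# receiver side (mirrors decrypt())
received = pickle.loads(wire)
received["indices"] = compressor.decompress(received["indices"])
received["params"] = compressor.decompress_float(received["params"])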
......@@ -115,6 +198,9 @@ class TCPRandomWalk(TCPRandomWalkBase):
total_procs,
addresses_filepath,
compress=False,
offset=2000,
compression_package=None,
compression_class=None,
sampler="equi",
):
"""
......@@ -132,18 +218,15 @@ class TCPRandomWalk(TCPRandomWalkBase):
Total number of processes
"""
self.total_procs = total_procs
self.rank = rank
self.machine_id = machine_id
self.mapping = mapping
self.addresses_filepath = addresses_filepath
self.uid = mapping.get_uid(rank, machine_id)
self.compress = compress
super().__init__(rank, machine_id, mapping, total_procs, addresses_filepath, compress, offset, compression_package, compression_class)
self.sampler = sampler
self.total_bytes = 0
self.send_queue = mp.Queue(1000)
self.recv_queue = mp.Queue(1000)
# Since we only add to these counters, they would not need to share the same lock; intermediate results may be inconsistent
self.lock = Lock()
self.total_data = Value(c_long, 0, lock=self.lock)
self.total_meta = Value(c_long, 0, lock=self.lock)
self.total_bytes = Value(c_long, 0, lock=self.lock)
self.flag_running = Value(c_int, 0, lock=False)
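total_data, total_meta and total_bytes are now multiprocessing Values backed by one shared Lock, so the forked TCPRandomWalkInternal process and the parent process update and read the same counters. A self-contained sketch of that pattern (names and numbers are illustrative):

from ctypes import c_long
from multiprocessing import Lock, Process
from multiprocessing.sharedctypes import Value

def worker(counter, n):
    for _ in range(n):
        # += on .value is a read-modify-write, so take the Value's lock explicitly
        with counter.get_lock():
            counter.value += 1

if __name__ == "__main__":
    lock = Lock()
    total_bytes = Value(c_long, 0, lock=lock)
    p = Process(target=worker, args=(total_bytes, 1000))
    p.start()
    p.join()
    print(total_bytes.value)  # 1000

The bare .value += used in encrypt() and send() is the same kind of read-modify-write; if exact totals ever matter, wrapping those increments in get_lock() as above is the safer form.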
def connect_neighbors(self, neighbors):
......@@ -170,7 +253,13 @@ class TCPRandomWalk(TCPRandomWalkBase):
self.recv_queue,
self.flag_running,
neighbors,
self.total_data,
self.total_meta,
self.total_bytes,
self.compress,
self.offset,
self.compression_package,
self.compression_class,
self.sampler,
),
start_method="fork",
......@@ -187,13 +276,14 @@ class TCPRandomWalk(TCPRandomWalkBase):
Received and decrypted data
"""
# logging.debug("Receive in TCPRandomWalk")
logging.debug("Receive in TCPRandomWalk")
try:
ret = self.recv_queue.get(block=block, timeout=None)  # already decrypted
logging.debug("Receive in TCPRandomWalk; post get")
return ret
except Empty:
return None, None
def send(self, uid, data):
def send(self, uid, data, encrypt=True):
"""
Send a message to a process.
......@@ -207,16 +297,21 @@ class TCPRandomWalk(TCPRandomWalkBase):
"""
if uid is not None:
logging.debug("Send to %i in TCPRandomWalk", uid)
self.send_queue.put((uid, data))
self.send_queue.put((uid, data, encrypt))
def disconnect_neighbors(self):
"""
Disconnects all neighbors.
"""
print("disconnect_neighbors")
self.flag_running.value = 0
del self.lock
self.send_queue.close() # this crashes
self.recv_queue.close()
self.ctx.join()
# self.send_queue.close() # TODO: is this needed
print("disconnect_neighbors: joined")
# self.send_queue.close() # this crashes
# self.recv_queue.close()
......@@ -232,7 +327,13 @@ class TCPRandomWalkInternal(TCPRandomWalkBase):
recv_queue: mp.Queue,
flag_running,
neighbors,
total_data,
total_meta,
total_bytes,
compress=False,
offset=2000,
compression_package=None,
compression_class=None,
sampler="equi",
):
"""
......@@ -252,57 +353,61 @@ class TCPRandomWalkInternal(TCPRandomWalkBase):
JSON file with machine_id -> ip mapping
"""
super().__init__(rank, machine_id, mapping, total_procs)
with open(addresses_filepath) as addrs:
self.ip_addrs = json.load(addrs)
self.total_procs = total_procs
self.rank = rank
self.machine_id = machine_id
self.mapping = mapping
self.uid = mapping.get_uid(rank, machine_id)
self.compress = compress
self.sampler = sampler
self.identity = str(self.uid).encode()
self.context = zmq.Context()
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.IDENTITY, self.identity)
self.router.bind(self.addr(rank, machine_id))
self.sent_disconnections = False
self.compress = compress
self.send_queue = send_queue
self.recv_queue = recv_queue
self.flag_running = flag_running
self.neighbors = neighbors
self.random_generator = torch.Generator()
self.rw_seed = (
self.random_generator.seed()
) # new random seed from the random device
logging.info("Machine %i has random seed %i for RW", self.uid, self.rw_seed)
self.rw_messages_stat = []
self.rw_double_count_stat = []
if self.sampler == "equi":
logging.info("rw_samper is rw_sampler_equi")
self.rw_sampler = weakref.WeakMethod(
self.rw_sampler_equi
) # self.rw_sampler_equi_check_history
elif self.sampler == "equi_check_history":
logging.info("rw_samper is rw_sampler_equi_check_history")
self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi_check_history)
else:
logging.info("rw_samper is rw_sampler_equi (default)")
self.rw_sampler = weakref.WeakMethod(
self.rw_sampler_equi
) # self.rw_sampler_equi_check_history
self.peer_sockets = dict()
self.barrier = set()
self.connect_neighbors(self.neighbors)
super().__init__(rank, machine_id, mapping, total_procs, addresses_filepath, compress, offset,
compression_package, compression_class)
print("post super")
try:
self.sampler = sampler
self.context = zmq.Context()
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.IDENTITY, self.identity)
self.router.bind(self.addr(rank, machine_id))
self.sent_disconnections = False
self.compress = compress
self.total_data = total_data
self.total_meta = total_meta
self.total_bytes = total_bytes
self.send_queue = send_queue
self.recv_queue = recv_queue
self.flag_running = flag_running
self.neighbors = neighbors
self.random_generator = torch.Generator()
self.rw_seed = (
self.random_generator.seed()
) # new random seed from the random device
logging.info("Machine %i has random seed %i for RW", self.uid, self.rw_seed)
self.rw_messages_stat = []
self.rw_double_count_stat = []
if self.sampler == "equi":
logging.info("rw_samper is rw_sampler_equi")
self.rw_sampler = weakref.WeakMethod(
self.rw_sampler_equi
) # self.rw_sampler_equi_check_history
elif self.sampler == "equi_check_history":
logging.info("rw_samper is rw_sampler_equi_check_history")
self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi_check_history)
else:
logging.info("rw_samper is rw_sampler_equi (default)")
self.rw_sampler = weakref.WeakMethod(
self.rw_sampler_equi
) # self.rw_sampler_equi_check_history
self.peer_sockets = dict()
self.barrier = set()
self.connect_neighbors(self.neighbors)
self.comm_loop()
except BaseException as e:
error_message = traceback.format_exc()
print(error_message)
logging.debug("GOT EXCEPTION")
logging.debug(error_message)
self.comm_loop()
self.recv_queue.close()
self.send_queue.close()
print("end")
def __del__(self):
"""
......@@ -311,9 +416,10 @@ class TCPRandomWalkInternal(TCPRandomWalkBase):
"""
self.context.destroy(linger=0)
logging.info(f"TCPRandomWalkInternal for {self.uid} is destroyed")
# exit(1)
# uncomment during debugging
self.recv_queue.close()
self.send_queue.close()
#self.recv_queue.close()
#self.send_queue.close()
def connect_neighbors(self, neighbors):
"""
......@@ -380,13 +486,13 @@ class TCPRandomWalkInternal(TCPRandomWalkBase):
sleep_time *= 1.1
try:
uid, data = self.send_queue.get_nowait()
uid, data, compress = self.send_queue.get_nowait()
# send may block if send queue is full
if uid is not None:
logging.debug("comm_loop will send to %i", uid)
else:
logging.debug("comm_loop will send rw")
self.send(uid, data)
self.send(uid, data, compress)
successes += 5
flushed = False
except Empty:
......@@ -462,7 +568,7 @@ class TCPRandomWalkInternal(TCPRandomWalkBase):
logging.debug(
"Forward rw {} to {}".format(data["visited"], new_neighbor)
)
self.send_queue.put((new_neighbor, data))
self.send_queue.put((new_neighbor, data, True))
else:
# the message has no more fuel so it is dropped
logging.info("dropped rw message with fuel %i ", data["fuel"])
......@@ -471,7 +577,7 @@ class TCPRandomWalkInternal(TCPRandomWalkBase):
"Received message from {} after putting into the queue".format(sender)
)
def send(self, uid, data):
def send(self, uid, data, encrypt=True):
"""
Send a message to a process.
......@@ -483,12 +589,18 @@ class TCPRandomWalkInternal(TCPRandomWalkBase):
Message as a Python dictionary
"""
logging.debug("send: rw? {}".format(data.get("rw", False)))
assert self.initialized == True
to_send = self.encrypt(data)
if encrypt:
rw = data.get("rw", False)
logging.debug("send: rw? {}".format(rw))
to_send = self.encrypt(data)
else:
rw = False
logging.debug("send: rw? {}".format(rw))
to_send = data
data_size = len(to_send)
self.total_bytes += data_size
if uid is None and data.get("rw", False): # a rw message
self.total_bytes.value += data_size
if uid is None and rw: # a rw message
if self.flag_running.value == 1: # Do not send rw if we are shutting down
uid = self.rw_sampler()(data)
logging.debug("send: rw to {}".format(uid))
......
......@@ -3,6 +3,7 @@ import json
import logging
import math
import os
import multiprocessing
import torch
from matplotlib import pyplot as plt
......@@ -432,17 +433,21 @@ class Node:
"grad_mean": {},
"grad_std": {},
}
results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
if type(self.communication.total_bytes) == multiprocessing.sharedctypes.Synchronized:
results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes.value
else:
results_dict["total_bytes"][iteration + 1] = self.communication.total_bytes
if hasattr(self.communication, "total_meta"):
results_dict["total_meta"][
iteration + 1
] = self.communication.total_meta
if type(self.communication.total_meta) == multiprocessing.sharedctypes.Synchronized:
results_dict["total_meta"][iteration + 1] = self.communication.total_meta.value
else:
results_dict["total_meta"][iteration + 1] = self.communication.total_meta
if hasattr(self.communication, "total_data"):
results_dict["total_data_per_n"][
iteration + 1
] = self.communication.total_data
if type(self.communication.total_data) == multiprocessing.sharedctypes.Synchronized:
results_dict["total_data_per_n"][iteration + 1] = self.communication.total_data.value
else:
results_dict["total_data_per_n"][iteration + 1] = self.communication.total_data
if hasattr(self.sharing, "mean"):
results_dict["grad_mean"][iteration + 1] = self.sharing.mean
if hasattr(self.sharing, "std"):
......@@ -504,6 +509,7 @@ class Node:
"w",
) as of:
json.dump(self.model.shared_parameters_counter.numpy().tolist(), of)
logging.info("disconnect neighbors")
self.communication.disconnect_neighbors()
logging.info("Storing final weight")
self.model.dump_weights(self.weights_store_dir, self.uid, iteration)
......
......@@ -11,7 +11,11 @@ from decentralizepy.sharing.Sharing import Sharing
class DPSGDRW(Sharing):
"""
API defining who to share with and what, and what to do on receiving
Alternative implementation of D-PSGD. Here, the weights are sent before the gradients are calculated (at the beginning of the round, which is
equivalent to sending them at the end of the preceding round).
In _post_step it also has the option to send RW messages.
This implementation will only work together with TCPRandomWalk.
"""
......@@ -65,49 +69,9 @@ class DPSGDRW(Sharing):
self.rw_double_count_stat = []
self.rw_length = rw_length
with torch.no_grad():
self.init_model = {}
for k, v in self.model.state_dict().items():
self.init_model[k] = v.clone().detach()
self.number_of_neighbors = len(self.my_neighbors)
def serialized_model(self, model=None):
"""
Convert model to a dictionary. Here we can choose how much to share
Returns
-------
dict
Model converted to dict
"""
m = dict()
if model is None:
model = self.model.state_dict()
for key, val in model.items():
m[key] = val.clone().detach()
# self.total_data += len(self.communication.encrypt(m[key])) TODO: need to count this per link
data = {"data": m}
return data
def deserialized_model(self, m):
"""
Convert received dict to state_dict.
Parameters
----------
m : dict
received dict
Returns
-------
state_dict
state_dict of received
"""
return m["data"]
def _post_step(self):
"""
Called at the end of step.
......@@ -117,7 +81,7 @@ class DPSGDRW(Sharing):
def send():
# for now, the data is sent twice to keep the code simpler
rw_data = {
"data": self.serialized_model(self.init_model)["data"],
"params": self.init_model.numpy(),
"rw": True,
"degree": self.number_of_neighbors,
"iteration": self.communication_round,
......@@ -130,6 +94,7 @@ class DPSGDRW(Sharing):
rw_chance = self.rw_chance
while rw_chance >= 1.0:
# TODO: make sure they are not sent to the same neighbour
print("send RW")
send()
rw_chance -= 1
rw_now = torch.rand(size=(1,), generator=self.random_generator).item()
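The loop above lets rw_chance be fractional: the integer part is the number of RWs sent every round and the remainder is the probability of one extra RW (assuming the comparison hidden by the hunk boundary checks rw_now against the remaining rw_chance). A small illustration:

import torch

rw_chance = 2.3
generator = torch.Generator()
sends = 0
while rw_chance >= 1.0:
    sends += 1                     # corresponds to send() above
    rw_chance -= 1
rw_now = torch.rand(size=(1,), generator=generator).item()
if rw_now < rw_chance:             # assumed comparison against the fractional remainder
    sends += 1
# over many rounds the expected number of RWs per round is 2.3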
......@@ -152,6 +117,7 @@ class DPSGDRW(Sharing):
)
data_dict = self.deserialized_model(data)
is_rw = data.get("rw", False)
logging.debug(f"is rw{is_rw}")
model_data = data_dict
add = True
if is_rw:
......@@ -233,29 +199,31 @@ class DPSGDRW(Sharing):
total[key] = value * weight
if len(total) != 0:
for (
key,
value,
) in self.init_model.items(): # self.model.state_dict().items():
init_dict = self.deserialized_model({"params": self.init_model.numpy()})
for (key, value) in init_dict.items():
# self.model.state_dict().items():
total[key] += (1 - weight_total) * value # Metro-Hastings
# apply the update to the model
for key, value in self.model.state_dict().items():
# subtract the model change from the new average weight
total[key] = total[key] + (value - self.init_model[key])
total[key] = total[key] + (value - init_dict[key])
self.model.load_state_dict(total)
else:
logging.debug("Node did not receive nothing")
# The first round is completely a local update
self.init_model = {}
for k, v in self.model.state_dict().items():
self.init_model[k] = v.clone().detach()
to_cat = []
for _, v in self.model.state_dict().items():
vf = v.clone().detach().flatten()
to_cat.append(vf)
self.init_model = torch.cat(to_cat)
def step(self):
"""
Perform a sharing step. Implements D-PSGD.
"""
logging.info("--- COMMUNICATION ROUND {} ---".format(self.communication_round))
self._pre_step()
logging.info("Waiting for messages from neighbors")
if self.communication_round != 0:
......@@ -293,9 +261,10 @@ class DPSGDRW(Sharing):
iter_neighbors = self.get_neighbors(all_neighbors)
data["degree"] = len(all_neighbors)
data["iteration"] = self.communication_round
encrypted = self.communication.encrypt(data)
for i, neighbor in enumerate(iter_neighbors):
logging.debug("sending in DPSGDRWAsync to %i", neighbor)
self.communication.send(neighbor, data)
self.communication.send(neighbor, encrypted, False)
def __del__(self):
if len(self.rw_messages_stat) != 0 and len(self.rw_double_count_stat) != 0:
......
......@@ -12,7 +12,10 @@ from decentralizepy.sharing.DPSGDRW import DPSGDRW
class DPSGDRWAsync(DPSGDRW):
"""
API defining who to share with and what, and what to do on receiving
Alternative implementation of D-PSGD. It is completely asynchronous and based on DPSGDRW.
rw_chance sets the number of neighbours to which a RW is sent; it should be greater than 1.
If rw_length is set to 1, it behaves like asynchronous D-PSGD.
This implementation will only work together with TCPRandomWalk.
"""
......
......@@ -55,10 +55,13 @@ class Sharing:
self.shapes = []
self.lens = []
with torch.no_grad():
to_cat = []
for _, v in self.model.state_dict().items():
self.shapes.append(v.shape)
t = v.flatten().numpy()
self.lens.append(t.shape[0])
vf = v.clone().detach().flatten()
to_cat.append(vf)
self.lens.append(vf.shape[0])
self.init_model = torch.cat(to_cat)
def received_from_all(self):
"""
......
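Recording self.shapes and self.lens next to the flattened init_model is what allows the single concatenated tensor to be turned back into a state_dict later. A hedged round-trip sketch (the helper names are illustrative, not the repository's):

from collections import OrderedDict

import torch

def flatten_state_dict(state_dict):
    shapes, lens, flat = [], [], []
    for _, v in state_dict.items():
        shapes.append(v.shape)
        vf = v.clone().detach().flatten()
        lens.append(vf.shape[0])
        flat.append(vf)
    return torch.cat(flat), shapes, lens

def unflatten(flat, keys, shapes, lens):
    out = OrderedDict()
    for key, shape, chunk in zip(keys, shapes, torch.split(flat, lens)):
        out[key] = chunk.reshape(shape)
    return out

# round trip on a small model
model = torch.nn.Linear(4, 2)
flat, shapes, lens = flatten_state_dict(model.state_dict())
restored = unflatten(flat, list(model.state_dict().keys()), shapes, lens)
model.load_state_dict(restored)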
......@@ -9,9 +9,11 @@ import torch.multiprocessing as mp
from decentralizepy.sharing.DPSGDRW import DPSGDRW
class SharingWithRWAsyncDynamic(DPSGDRW):
class SharingDynamicGraph(DPSGDRW):
"""
API defining who to share with and what, and what to do on receiving
This implementation of Sharing uses RWs to sample new neighbours to connect to; the graph is dynamic. The number of
neighbours is bounded by neighbor_bound = (min, max), although it may temporarily exceed the max.
This implementation will only work together with TCPRandomWalkRouting.
"""
......
......@@ -8,7 +8,10 @@ import torch
class SharingWithRW:
"""
API defining who to share with and what, and what to do on receiving
This implementation of Sharing with RW runs synchronized rounds with its neighbours. On top of that, it sends
RW messages to a neighbour inside the same data package; hence, this method does not work with encryption.
A random walk message arrives at its destination with a lag of rw_length - 1.
This implementation is supposed to run with TCP.py.
"""
......
......@@ -11,7 +11,9 @@ from decentralizepy.sharing.DPSGDRW import DPSGDRW
class SharingWithRWAsync(DPSGDRW):
"""
API defining who to share with and what, and what to do on receiving
This implementation of Sharing with RW runs synchronized rounds with its neighbours. On top of that, it runs
additional RW steps that are sent asynchronously in _post_step.
This implementation will only work together with TCPRandomWalk.
"""
......@@ -120,10 +122,10 @@ class SharingWithRWAsync(DPSGDRW):
iter_neighbors = self.get_neighbors(all_neighbors)
data["degree"] = len(all_neighbors)
data["iteration"] = self.communication_round
encrypted = self.communication.encrypt(data)
for i, neighbor in enumerate(iter_neighbors):
logging.debug("sending in SharedWithRWAsync to %i", neighbor)
self.communication.send(neighbor, data)
self.communication.send(neighbor, encrypted, False)
logging.info("Waiting for messages from neighbors")
while not self.received_from_all():
......
......@@ -84,7 +84,7 @@ def get_args():
parser.add_argument("-tea", "--train_evaluate_after", type=int, default=1)
parser.add_argument("-ro", "--reset_optimizer", type=int, default=1)
parser.add_argument("-ctr", "--centralized_train_eval", type=int, default=0)
parser.add_argument("-cte", "--centralized_test_eval", type=int, default=1)
parser.add_argument("-cte", "--centralized_test_eval", type=int, default=0)
args = parser.parse_args()
return args
......