import importlib import json import logging import lzma import pickle import time import weakref from collections import deque from ctypes import c_int, c_long from multiprocessing.sharedctypes import Value from queue import Empty import traceback from multiprocessing import Lock import torch import zmq import faulthandler faulthandler.enable() import torch.multiprocessing as mp from decentralizepy.communication.Communication import Communication HELLO = b"HELLO" BYE = b"BYE" class TCPRandomWalkBase(Communication): """ TCPRandomWalkBase that copies only the encrypt and decrypt functions from TCP.py This dummy is needed as the sharing interfaces can call encrpt and decrypt, so we need the functions both in TCPRandomWalk and TCPRandomWalkInternal """ def addr(self, rank, machine_id): """ Returns TCP address of the process. Parameters ---------- rank : int Local rank of the process machine_id : int Machine id of the process Returns ------- str Full address of the process using TCP """ machine_addr = self.ip_addrs[str(machine_id)] port = rank + self.offset return "tcp://{}:{}".format(machine_addr, port) def __init__( self, rank, machine_id, mapping, total_procs, addresses_filepath, compress=False, offset=2000, compression_package=None, compression_class=None, ): """ Constructor Parameters ---------- rank : int Local rank of the process machine_id : int Machine id of the process mapping : decentralizepy.mappings.Mapping uid, rank, machine_id invertible mapping total_procs : int Total number of processes addresses_filepath : str JSON file with machine_id -> ip mapping compression_package : str Import path of a module that implements the compression.Compression.Compression class compression_class : str Name of the compression class inside the compression package """ super().__init__(rank, machine_id, mapping, total_procs) self.addresses_filepath = addresses_filepath self.compress = compress self.offset = 10000 + offset self.compression_package = compression_package self.compression_class = compression_class with open(addresses_filepath) as addrs: self.ip_addrs = json.load(addrs) self.identity = str(self.uid).encode() if compression_package and compression_class: compressor_module = importlib.import_module(compression_package) compressor_class = getattr(compressor_module, compression_class) self.compressor = compressor_class() logging.info(f"Using the {compressor_class} to compress the data") else: assert not self.compress self.total_data = 0 # dummy values self.total_meta = 0 def encrypt(self, data): """ Encode data as python pickle. Parameters ---------- data : dict Data dict to send Returns ------- byte Encoded data """ logging.debug("in encrypt") if self.compress: logging.debug("in encrypt: compress") if "indices" in data: data["indices"] = self.compressor.compress(data["indices"]) assert "params" in data data["params"] = self.compressor.compress_float(data["params"]) data_len = len(pickle.dumps(data["params"])) output = pickle.dumps(data) # the compressed meta data gets only a few bytes smaller after pickling self.total_meta.value += (len(output) - data_len) self.total_data.value += data_len else: logging.debug("in encrypt: else") output = pickle.dumps(data) # centralized testing uses its own instance if type(data) == dict: assert "params" in data data_len = len(pickle.dumps(data["params"])) self.total_meta.value += (len(output) - data_len) self.total_data.value += data_len return output def decrypt(self, sender, data): """ Decode received pickle data. Parameters ---------- sender : byte sender of the data data : byte Data received Returns ------- tuple (sender: int, data: dict) """ logging.debug("in decrypt") sender = int(sender.decode()) if self.compress: logging.debug("in decrypt:comp") data = pickle.loads(data) if "indices" in data: data["indices"] = self.compressor.decompress(data["indices"]) if "params" in data: data["params"] = self.compressor.decompress_float(data["params"]) else: logging.debug("in decrypt:else") data = pickle.loads(data) return sender, data class TCPRandomWalk(TCPRandomWalkBase): """ Wrapper for TCPRandomWalkInternal, mostly a copy of Communication Connect """ def __init__( self, rank, machine_id, mapping, total_procs, addresses_filepath, compress=False, offset=2000, compression_package=None, compression_class=None, sampler="equi", ): """ Constructor Parameters ---------- rank : int Local rank of the process machine_id : int Machine id of the process mapping : decentralizepy.mappings.Mapping uid, rank, machine_id invertible mapping total_procs : int Total number of processes """ super().__init__(rank, machine_id, mapping, total_procs, addresses_filepath, compress, offset, compression_package, compression_class) self.sampler = sampler self.send_queue = mp.Queue(1000) self.recv_queue = mp.Queue(1000) # Since we are only adding these do not need to share the same lock, intermediary results may be inconsistent self.lock = mp.Lock() self.total_data = mp.Value(c_long, 0, lock = self.lock) self.total_meta = mp.Value(c_long, 0, lock = self.lock) self.total_bytes = mp.Value(c_long, 0, lock = self.lock) self.flag_running = mp.Value(c_int, 0, lock=False) def connect_neighbors(self, neighbors): """ Spawns TCPRandomWalkInternal. It will connect to the neighbours. This function should only be called once. Parameters ---------- neighbors : list(int) List of neighbors """ self.flag_running.value = 1 self.ctx = mp.start_processes( lambda *args: TCPRandomWalkInternal(*(args[1:])), args=( self.rank, self.machine_id, self.mapping, self.total_procs, self.addresses_filepath, self.send_queue, self.recv_queue, self.flag_running, neighbors, self.total_data, self.total_meta, self.total_bytes, self.compress, self.offset, self.compression_package, self.compression_class, self.sampler, ), start_method="fork", join=False, ) def receive(self, block=True, timeout=None): """ Returns a received message. It blocks if no message has been received. Returns ---------- dict Received and decrypted data """ logging.debug("Receive in TCPRandomWalk") try: return self.recv_queue.get(block=block, timeout=None) # already decrypted logging.debug("Receive in TCPRandomWalk; post get") except Empty: return None, None def send(self, uid, data, encrypt=True): """ Send a message to a process. Parameters ---------- uid : int Neighbor's unique ID data : dict Message as a Python dictionary """ if uid is not None: logging.debug("Send to %i in TCPRandomWalk", uid) self.send_queue.put((uid, data, encrypt)) def disconnect_neighbors(self): """ Disconnects all neighbors. """ print("disconnect_neighbors") self.flag_running.value = 0 time.sleep(4) self.send_queue.close() # this crashes self.recv_queue.close() #del self.lock self.send_queue.join_thread() self.recv_queue.join_thread() self.ctx.join() print(f"disconnect_neighbors: joined {self.uid}") class TCPRandomWalkInternal(TCPRandomWalkBase): def __init__( self, rank, machine_id, mapping, total_procs, addresses_filepath, send_queue: mp.Queue, recv_queue: mp.Queue, flag_running, neighbors, total_data, total_meta, total_bytes, compress=False, offset=2000, compression_package=None, compression_class=None, sampler="equi", ): """ Constructor Parameters ---------- rank : int Local rank of the process machine_id : int Machine id of the process mapping : decentralizepy.mappings.Mapping uid, rank, machine_id invertible mapping total_procs : int Total number of processes addresses_filepath : str JSON file with machine_id -> ip mapping """ super().__init__(rank, machine_id, mapping, total_procs, addresses_filepath, compress, offset, compression_package, compression_class) print("post super") try: self.sampler = sampler self.context = zmq.Context() self.router = self.context.socket(zmq.ROUTER) self.router.setsockopt(zmq.IDENTITY, self.identity) self.router.bind(self.addr(rank, machine_id)) self.sent_disconnections = False self.compress = compress self.total_data = total_data self.total_meta = total_meta self.total_bytes = total_bytes self.send_queue = send_queue self.recv_queue = recv_queue self.flag_running = flag_running self.neighbors = neighbors self.random_generator = torch.Generator() self.rw_seed = ( self.random_generator.seed() ) # new random seed from the random device logging.info("Machine %i has random seed %i for RW", self.uid, self.rw_seed) self.rw_messages_stat = [] self.rw_double_count_stat = [] if self.sampler == "equi": logging.info("rw_samper is rw_sampler_equi") self.rw_sampler = weakref.WeakMethod( self.rw_sampler_equi ) # self.rw_sampler_equi_check_history elif self.sampler == "equi_check_history": logging.info("rw_samper is rw_sampler_equi_check_history") self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi_check_history) else: logging.info("rw_samper is rw_sampler_equi (default)") self.rw_sampler = weakref.WeakMethod( self.rw_sampler_equi ) # self.rw_sampler_equi_check_history self.peer_sockets = dict() self.barrier = set() self.connect_neighbors(self.neighbors) self.comm_loop() except BaseException as e: error_message = traceback.format_exc() print(error_message) print("GOT EXCEPTION") logging.debug("GOT EXCEPTION") logging.debug(error_message) while not self.recv_queue.empty(): print(f"{self.uid}: clear rcv") _ = self.recv_queue.get_nowait() self.recv_queue.close() self.recv_queue.join_thread() print(f"{self.uid}: joined recv") while not self.send_queue.empty(): print(f"{self.uid}: clear snd") _ = self.send_queue.get_nowait() print(f"{self.uid}: joined send") self.send_queue.close() self.send_queue.join_thread() print("end") def __del__(self): """ Destroys zmq context """ self.context.destroy(linger=0) logging.info(f"TCPRandomWalkInternal for {self.uid} is destroyed") # exit(1) # uncomment during debugging #self.recv_queue.close() #self.send_queue.close() def connect_neighbors(self, neighbors): """ Connects all neighbors. Sends HELLO. Waits for HELLO. Caches any data received while waiting for HELLOs. Parameters ---------- neighbors : list(int) List of neighbors Raises ------ RuntimeError If received BYE while waiting for HELLO """ logging.info("Sending connection request to neighbors") for uid in neighbors: logging.debug("Connecting to my neighbour: {}".format(uid)) id = str(uid).encode() req = self.context.socket(zmq.DEALER) req.setsockopt(zmq.IDENTITY, self.identity) req.connect(self.addr(*self.mapping.get_machine_and_rank(uid))) self.peer_sockets[id] = req req.send(HELLO) num_neighbors = len(neighbors) while len(self.barrier) < num_neighbors: sender, recv = self.router.recv_multipart() if recv == HELLO: logging.debug("Received {} from {}".format(HELLO, sender)) self.barrier.add(sender) elif recv == BYE: logging.debug("Received {} from {}".format(BYE, sender)) raise RuntimeError( "A neighbour wants to disconnect before training started!" ) else: logging.debug( "Received message from {} @ connect_neighbors".format(sender) ) self.recv_queue.put(self.decrypt(sender, recv)) logging.info("Connected to all neighbors") self.initialized = True def comm_loop(self): # TODO: May want separate loops for send and receive sleep_time = 0.001 while True: successes = 1 flushed = True try: sender, recv = self.router.recv_multipart(flags=zmq.NOBLOCK) logging.debug("comm_loop received from %i", int(sender.decode())) self.receive(sender, recv) successes += 5 flushed = False except zmq.ZMQError: # logging.debug("zmq error due to empty recv_multipart") sleep_time *= 1.1 try: uid, data, compress = self.send_queue.get_nowait() # send may block if send queue is full if uid is not None: logging.debug("comm_loop will send to %i", uid) else: logging.debug("comm_loop will send rw") self.send(uid, data, compress) successes += 5 flushed = False except Empty: sleep_time *= 1.1 # logging.debug("empty send queue %i", self.send_queue.empty()) sleep_time /= successes sleep_time = max(0.00001, sleep_time) sleep_time = min(0.005, sleep_time) # logging.debug("comm loop sleeping for %f", sleep_time) time.sleep(sleep_time) if self.flag_running.value == 0 and flushed: break self.disconnect_neighbors() print(f"DISCONNECTED {self.uid}") def receive(self, sender, recv): """ Returns ONE message received. Returns ---------- dict Received and decrypted data Raises ------ RuntimeError If received HELLO """ if recv == HELLO: logging.debug("Received {} from {}".format(HELLO, sender)) # TODO: how to properly shut down here self.flag_running.value = False self.recv_queue.close() self.send_queue.close() raise RuntimeError( "A neighbour wants to connect when everyone is connected!" ) elif recv == BYE: logging.debug("Received {} from {}".format(BYE, sender)) self.barrier.remove(sender) else: # TODO: here process receive of rw messages logging.debug("Received message from {}".format(sender)) src, data = self.decrypt(sender, recv) # ned new data object such that forwarding does not change the internal fields # We add our selves to visited so, it could trigger "RW message was already once received" new_data = data.copy() if data.get("rw", False): new_data["visited"] = new_data["visited"].copy() src = new_data["visited"][ 0 ] # the original src is the neighbour that sent it to us. logging.debug( "Received message from {} was rw {}".format(sender, data["visited"]) ) if data["fuel"] > 0: new_neighbor = self.rw_sampler()(data) if new_neighbor == None: logging.info( "dropped rw message due to no new neigbor being available: %s", str(data["visited"]), ) else: data["fuel"] -= 1 assert data["fuel"] + 1 == new_data["fuel"] visited = data["visited"] visited.append(self.uid) data["visited"] = visited logging.debug( "Forward rw {} to {}".format(data["visited"], new_neighbor) ) self.send_queue.put((new_neighbor, data, True)) else: # the message has no more fuel so it is dropped logging.info("dropped rw message with fuel %i ", data["fuel"]) self.recv_queue.put((src, new_data)) logging.debug( "Received message from {} after putting into the queue".format(sender) ) def send(self, uid, data, encrypt = True): """ Send a message to a process. Parameters ---------- uid : int Neighbor's unique ID. If it is none then it means we sample the neighbour! data : dict Message as a Python dictionary """ assert self.initialized == True if encrypt: rw = data.get("rw", False) logging.debug("send: rw? {}".format(rw)) to_send = self.encrypt(data) else: rw = False logging.debug("send: rw? {}".format(rw)) to_send = data data_size = len(to_send) self.total_bytes.value += data_size if uid is None and rw: # a rw message if self.flag_running.value == 1: # Do not send rw if we are shutting down uid = self.rw_sampler()(data) logging.debug("send: rw to {}".format(uid)) else: logging.debug("send: rw dropped due to being in shutdown") return id = str(uid).encode() self.peer_sockets[id].send(to_send) logging.debug("{} sent the message to {}.".format(self.uid, uid)) logging.info("Sent this round: {}".format(data_size)) def disconnect_neighbors(self): """ Disconnects all neighbors. """ assert self.initialized == True if not self.sent_disconnections: logging.info("Disconnecting neighbors") for sock in self.peer_sockets.values(): sock.send(BYE) self.sent_disconnections = True while len(self.barrier): sender, recv = self.router.recv_multipart() if recv == BYE: logging.debug("Received {} from {}".format(BYE, sender)) self.barrier.remove(sender) else: # this can happen now due to async logging.info("Received unexpected message from {}".format(sender)) sender, data = self.decrypt(sender, recv) if data.get("rw", False): logging.info("Message was rw {}".format(data["visited"])) else: logging.info("Message was normal") # raise RuntimeError( # "Received a message when expecting BYE from {}".format(sender) # ) for sock in self.peer_sockets.values(): sock.close() self.router.close() def rw_sampler_equi(self, message): index = torch.randint( 0, len(self.neighbors), size=(1,), generator=self.random_generator ).item() logging.debug( "rw_sampler_equi selected index {} of {} {}".format( index, self.neighbors, type(self.neighbors) ) ) return list(self.neighbors)[index] def rw_sampler_equi_check_history(self, message): if message is None: # RW starts from here index = torch.randint( 0, len(self.neighbors), size=(1,), generator=self.random_generator ).item() logging.debug( "rw_sampler_equi_check_history selected index {} of {} {}".format( index, self.neighbors, type(self.neighbors) ) ) return list(self.neighbors)[index] else: visited = set(message["visited"]) neighbors = self.neighbors # is already a set possible_neigbors = neighbors.difference(visited) if len(possible_neigbors) == 0: return None else: index = torch.randint( 0, len(possible_neigbors), size=(1,), generator=self.random_generator, ).item() return list(possible_neigbors)[index] def rw_sampler_mh(self, message): # Metro hastings version of the sampler # Samples the neighbour based on the MH weights, allows self loops (will just decrease fuel and reroll!) pass