import faulthandler
import importlib
import json
import logging
import lzma
import pickle
import time
import traceback
import weakref
from collections import deque
from ctypes import c_int, c_long
from multiprocessing import Lock
from multiprocessing.sharedctypes import Value
from queue import Empty

# Enabled before the heavy third-party imports, as in the original ordering.
faulthandler.enable()

import torch
import torch.multiprocessing as mp
import zmq

from decentralizepy.communication.Communication import Communication

HELLO = b"HELLO"
BYE = b"BYE"


class TCPRandomWalkBase(Communication):
    """
    TCPRandomWalkBase copies only the encrypt and decrypt functions from TCP.py.
    This dummy is needed because the sharing interfaces can call encrypt and decrypt,
    so we need the functions both in TCPRandomWalk and TCPRandomWalkInternal.

    """

    def addr(self, rank, machine_id):
        """
        Returns TCP address of the process.

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process

        Returns
        -------
        str
            Full address of the process using TCP

        """
        # ip_addrs is loaded from JSON, so its keys are strings.
        machine_addr = self.ip_addrs[str(machine_id)]
        # Every local rank gets its own port above the configured offset.
        port = rank + self.offset
        return "tcp://{}:{}".format(machine_addr, port)
    
    def __init__(
        self,
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        compress=False,
        offset=2000,
        compression_package=None,
        compression_class=None,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process
        mapping : decentralizepy.mappings.Mapping
            uid, rank, machine_id invertible mapping
        total_procs : int
            Total number of processes
        addresses_filepath : str
            JSON file with machine_id -> ip mapping
        compress : bool
            Whether encrypt/decrypt run the payload through the compressor.
            Only allowed when both compression arguments are given.
        offset : int
            Added to 10000 to obtain the base port used by addr
        compression_package : str
            Import path of a module that implements the compression.Compression.Compression class
        compression_class : str
            Name of the compression class inside the compression package

        """
        super().__init__(rank, machine_id, mapping, total_procs)

        # Remember the raw configuration.
        self.compression_class = compression_class
        self.compression_package = compression_package
        self.compress = compress
        self.offset = offset + 10000
        self.addresses_filepath = addresses_filepath

        with open(addresses_filepath) as address_file:
            self.ip_addrs = json.load(address_file)

        # self.uid is not assigned in this block; presumably set by Communication.__init__.
        self.identity = str(self.uid).encode()

        compressor_configured = compression_package and compression_class
        if not compressor_configured:
            # Compression cannot be requested without a compressor to run it.
            assert not self.compress
        else:
            module = importlib.import_module(compression_package)
            compressor_class = getattr(module, compression_class)
            self.compressor = compressor_class()
            logging.info(f"Using the {compressor_class} to compress the data")

        # Dummy values.
        self.total_data = 0
        self.total_meta = 0

    def encrypt(self, data):
        """
        Encode data as python pickle.

        When ``self.compress`` is set, ``data["indices"]`` (if present) and
        ``data["params"]`` are replaced by their compressed form before
        pickling. The replacement happens in the dict that was passed in.
        The byte counters ``total_data`` (pickled size of ``params``) and
        ``total_meta`` (everything else) are updated as a side effect.

        Parameters
        ----------
        data : dict
            Data dict to send

        Returns
        -------
        byte
            Encoded data

        """
        logging.debug("in encrypt")
        if self.compress:
            logging.debug("in encrypt: compress")
            # NOTE(review): this overwrites entries of the caller's dict.
            if "indices" in data:
                data["indices"] = self.compressor.compress(data["indices"])

            assert "params" in data
            data["params"] = self.compressor.compress_float(data["params"])
            data_len = len(pickle.dumps(data["params"]))
            output = pickle.dumps(data)
            # the compressed meta data gets only a few bytes smaller after pickling
            self.total_meta.value += len(output) - data_len
            self.total_data.value += data_len
        else:
            logging.debug("in encrypt: else")
            output = pickle.dumps(data)
            # centralized testing uses its own instance
            if type(data) is dict:
                assert "params" in data
                data_len = len(pickle.dumps(data["params"]))
                self.total_meta.value += len(output) - data_len
                self.total_data.value += data_len
        return output

    def decrypt(self, sender, data):
        """
        Decode received pickle data.

        Parameters
        ----------
        sender : byte
            sender of the data
        data : byte
            Data received

        Returns
        -------
        tuple
            (sender: int, data: dict)

        """
        logging.debug("in decrypt")
        sender = int(sender.decode())
        # SECURITY NOTE: pickle.loads executes arbitrary code from the payload.
        # This is only acceptable when every peer is trusted.
        if self.compress:
            logging.debug("in decrypt:comp")
            data = pickle.loads(data)
            if "indices" in data:
                data["indices"] = self.compressor.decompress(data["indices"])
            if "params" in data:
                data["params"] = self.compressor.decompress_float(data["params"])
        else:
            logging.debug("in decrypt:else")
            data = pickle.loads(data)
        return sender, data


class TCPRandomWalk(TCPRandomWalkBase):
    """
    Wrapper for TCPRandomWalkInternal, mostly a copy of Communication
    Connect

    """

    def __init__(
        self,
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        compress=False,
        offset=2000,
        compression_package=None,
        compression_class=None,
        sampler="equi",
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process
        mapping : decentralizepy.mappings.Mapping
            uid, rank, machine_id invertible mapping
        total_procs : int
            Total number of processes
        addresses_filepath : str
            JSON file with machine_id -> ip mapping
        compress : bool
            Whether payloads are compressed in encrypt/decrypt
        offset : int
            Port offset handed to the base class
        compression_package : str
            Import path of the module providing the compression class
        compression_class : str
            Name of the compression class inside the compression package
        sampler : str
            Name of the random-walk sampler forwarded to TCPRandomWalkInternal

        """
        super().__init__(
            rank,
            machine_id,
            mapping,
            total_procs,
            addresses_filepath,
            compress,
            offset,
            compression_package,
            compression_class,
        )
        self.sampler = sampler
        # Bounded queues that connect this wrapper with the communication process.
        self.send_queue = mp.Queue(1000)
        self.recv_queue = mp.Queue(1000)
        # The three byte counters share one lock. They are only ever added to,
        # so intermediary results may be inconsistent.
        self.lock = mp.Lock()
        self.total_data = mp.Value(c_long, 0, lock=self.lock)
        self.total_meta = mp.Value(c_long, 0, lock=self.lock)
        self.total_bytes = mp.Value(c_long, 0, lock=self.lock)
        # 1 while the communication process should keep running, 0 to stop it.
        self.flag_running = mp.Value(c_int, 0, lock=False)

    def connect_neighbors(self, neighbors):
        """
        Spawns TCPRandomWalkInternal. It will connect to the neighbours.
        This function should only be called once.

        Parameters
        ----------
        neighbors : list(int)
            List of neighbors

        """
        self.flag_running.value = 1
        # start_processes calls the target with the process index first,
        # so the lambda drops args[0] before building the internal object.
        # NOTE(review): self.compress is not part of this tuple, so the
        # compression flag never reaches the child process; confirm whether
        # it should be forwarded.
        self.ctx = mp.start_processes(
            lambda *args: TCPRandomWalkInternal(*(args[1:])),
            args=(
                self.rank,
                self.machine_id,
                self.mapping,
                self.total_procs,
                self.addresses_filepath,
                self.send_queue,
                self.recv_queue,
                self.flag_running,
                neighbors,
                self.total_data,
                self.total_meta,
                self.total_bytes,
                self.offset,
                self.compression_package,
                self.compression_class,
                self.sampler,
            ),
            start_method="fork",
            join=False,
        )

    def receive(self, block=True, timeout=None):
        """
        Returns a received message. It blocks if no message has been received.

        Parameters
        ----------
        block : bool
            Whether to wait for a message when the queue is empty
        timeout : float or None
            Maximum number of seconds to wait when blocking; None waits forever

        Returns
        ----------
        tuple or None
            Received and decrypted data as put on the queue by the
            communication process, or None when nothing was available

        """
        logging.debug("Receive in TCPRandomWalk")
        try:
            # Entries are already decrypted by the communication process.
            return self.recv_queue.get(block=block, timeout=timeout)
        except Empty:
            logging.debug("Receive in TCPRandomWalk; post get")
            return None

    def send(self, uid, data, encrypt=True):
        """
        Send a message to a process.

        The message is only queued here; the communication process takes it
        from the queue and performs the actual send.

        Parameters
        ----------
        uid : int
            Neighbor's unique ID
        data : dict
            Message as a Python dictionary
        encrypt : bool
            Forwarded unchanged together with the message

        """
        if uid is not None:
            logging.debug("Send to %i in TCPRandomWalk", uid)
        self.send_queue.put((uid, data, encrypt))

    def disconnect_neighbors(self):
        """
        Disconnects all neighbors.

        Clears the running flag, waits, closes both queues and joins the
        communication process.

        """
        print("disconnect_neighbors")
        self.flag_running.value = 0
        # Fixed grace period before the queues are closed.
        time.sleep(4)
        # NOTE(review): the original author marked this close() with
        # "this crashes"; the cause is not visible in this file.
        self.send_queue.close()
        self.recv_queue.close()
        self.send_queue.join_thread()
        self.recv_queue.join_thread()
        self.ctx.join()
        print(f"disconnect_neighbors: joined {self.uid}")


class TCPRandomWalkInternal(TCPRandomWalkBase):
    def __init__(
        self,
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        send_queue: mp.Queue,
        recv_queue: mp.Queue,
        flag_running,
        neighbors,
        total_data,
        total_meta,
        total_bytes,
        offset=2000,
        compression_package=None,
        compression_class=None,
        sampler="equi",
        compress=False,
    ):
        """
        Constructor. Builds the sockets, connects to the neighbours and then
        runs the communication loop until shutdown; it only returns after the
        loop has ended and both queues have been drained and closed.

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process
        mapping : decentralizepy.mappings.Mapping
            uid, rank, machine_id invertible mapping
        total_procs : int
            Total number of processes
        addresses_filepath : str
            JSON file with machine_id -> ip mapping
        send_queue : mp.Queue
            Queue of (uid, data, encrypt) tuples to send
        recv_queue : mp.Queue
            Queue on which received (sender, data) tuples are put
        flag_running : shared int
            The loop stops once this is 0 and nothing is left to process
        neighbors : list(int)
            List of neighbors
        total_data, total_meta, total_bytes : shared int
            Byte counters updated while sending
        offset : int
            Port offset handed to the base class
        compression_package : str
            Import path of the module providing the compression class
        compression_class : str
            Name of the compression class inside the compression package
        sampler : str
            "equi" or "equi_check_history"; anything else falls back to "equi"
        compress : bool
            Whether payloads are compressed. Placed last so that existing
            positional callers are unaffected.

        """
        super().__init__(
            rank,
            machine_id,
            mapping,
            total_procs,
            addresses_filepath,
            compress,
            offset,
            compression_package,
            compression_class,
        )
        print("post super")

        # Plain attribute assignments happen before the try block so that the
        # queue cleanup below works even when the socket setup fails.
        self.sampler = sampler
        self.sent_disconnections = False
        self.total_data = total_data
        self.total_meta = total_meta
        self.total_bytes = total_bytes
        self.send_queue = send_queue
        self.recv_queue = recv_queue
        self.flag_running = flag_running
        self.neighbors = neighbors

        try:
            self.context = zmq.Context()
            self.router = self.context.socket(zmq.ROUTER)
            self.router.setsockopt(zmq.IDENTITY, self.identity)
            self.router.bind(self.addr(rank, machine_id))

            self.random_generator = torch.Generator()
            self.rw_seed = (
                self.random_generator.seed()
            )  # new random seed from the random device
            logging.info("Machine %i has random seed %i for RW", self.uid, self.rw_seed)
            self.rw_messages_stat = []
            self.rw_double_count_stat = []
            # The sampler is held through a WeakMethod, so callers invoke it
            # as self.rw_sampler()(message).
            if self.sampler == "equi_check_history":
                logging.info("rw_samper is rw_sampler_equi_check_history")
                self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi_check_history)
            elif self.sampler == "equi":
                logging.info("rw_samper is rw_sampler_equi")
                self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi)
            else:
                logging.info("rw_samper is rw_sampler_equi (default)")
                self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi)
            self.peer_sockets = dict()
            self.barrier = set()
            self.connect_neighbors(self.neighbors)

            self.comm_loop()
        except BaseException:
            # Deliberately broad: this runs in a child process, so the
            # traceback is printed and logged before the queues are torn down.
            error_message = traceback.format_exc()
            print(error_message)
            print("GOT EXCEPTION")
            logging.debug("GOT EXCEPTION")
            logging.debug(error_message)

        # Drain leftovers before the queues are closed and joined.
        while not self.recv_queue.empty():
            print(f"{self.uid}: clear rcv")
            try:
                _ = self.recv_queue.get_nowait()
            except Empty:
                # empty() is not reliable for multiprocessing queues.
                break
        self.recv_queue.close()
        self.recv_queue.join_thread()
        print(f"{self.uid}: joined recv")
        while not self.send_queue.empty():
            print(f"{self.uid}: clear snd")
            try:
                _ = self.send_queue.get_nowait()
            except Empty:
                break
        print(f"{self.uid}: joined send")
        self.send_queue.close()
        self.send_queue.join_thread()
        print("end")

    def __del__(self):
        """
        Destroys zmq context

        """
        # The context does not exist when the constructor failed early.
        context = getattr(self, "context", None)
        if context is not None:
            context.destroy(linger=0)
        uid = getattr(self, "uid", None)
        logging.info(f"TCPRandomWalkInternal for {uid} is destroyed")

    def connect_neighbors(self, neighbors):
        """
        Connects all neighbors. Sends HELLO. Waits for HELLO.
        Data that arrives while waiting for HELLOs is decrypted and put on
        the receive queue.

        Parameters
        ----------
        neighbors : list(int)
            List of neighbors

        Raises
        ------
        RuntimeError
            If received BYE while waiting for HELLO

        """
        logging.info("Sending connection request to neighbors")
        for neighbor_uid in neighbors:
            logging.debug("Connecting to my neighbour: {}".format(neighbor_uid))
            dealer = self.context.socket(zmq.DEALER)
            dealer.setsockopt(zmq.IDENTITY, self.identity)
            peer_address = self.addr(*self.mapping.get_machine_and_rank(neighbor_uid))
            dealer.connect(peer_address)
            # Sockets are keyed by the encoded uid.
            self.peer_sockets[str(neighbor_uid).encode()] = dealer
            dealer.send(HELLO)

        expected = len(neighbors)
        while len(self.barrier) < expected:
            sender, payload = self.router.recv_multipart()

            if payload == HELLO:
                logging.debug("Received {} from {}".format(HELLO, sender))
                self.barrier.add(sender)
                continue

            if payload == BYE:
                logging.debug("Received {} from {}".format(BYE, sender))
                raise RuntimeError(
                    "A neighbour wants to disconnect before training started!"
                )

            # Regular data that arrived early.
            logging.debug(
                "Received message from {} @ connect_neighbors".format(sender)
            )
            self.recv_queue.put(self.decrypt(sender, payload))

        logging.info("Connected to all neighbors")
        self.initialized = True

    def comm_loop(self):
        """
        Alternates between a non-blocking receive on the router socket and a
        non-blocking take from the send queue.

        The pause between iterations grows by 10% for every idle half and is
        divided by the success count after every iteration, clamped to
        [0.00001, 0.005] seconds. The loop ends once the running flag is 0
        and an iteration neither received nor sent anything; the neighbours
        are then disconnected.

        """
        # TODO: May want separate loops for send and receive
        sleep_time = 0.001
        while True:
            successes = 1
            flushed = True
            try:
                sender, recv = self.router.recv_multipart(flags=zmq.NOBLOCK)
                logging.debug("comm_loop received from %i", int(sender.decode()))
                self.receive(sender, recv)
                successes += 5
                flushed = False
            except zmq.ZMQError:
                # Raised when nothing is pending on the non-blocking receive.
                # NOTE(review): this also hides any other ZMQError.
                sleep_time *= 1.1

            try:
                uid, data, encrypt = self.send_queue.get_nowait()
                # send may block if send queue is full
                if uid is not None:
                    logging.debug("comm_loop will send to %i", uid)
                else:
                    logging.debug("comm_loop will send rw")
                self.send(uid, data, encrypt)
                successes += 5
                flushed = False
            except Empty:
                sleep_time *= 1.1

            sleep_time /= successes
            sleep_time = max(0.00001, sleep_time)
            sleep_time = min(0.005, sleep_time)
            time.sleep(sleep_time)

            if self.flag_running.value == 0 and flushed:
                break
        self.disconnect_neighbors()
        print(f"DISCONNECTED {self.uid}")

    def receive(self, sender, recv):
        """
        Handles ONE message taken from the router socket.

        BYE removes the sender from the barrier. Any other payload is
        decrypted and put on the receive queue as (src, data). A random-walk
        message (``data["rw"]`` truthy) with fuel left is additionally put on
        the send queue for a sampled neighbour, with its fuel decremented and
        this node appended to ``visited``.

        Parameters
        ----------
        sender : byte
            Identity of the sending socket
        recv : byte
            Raw payload

        Raises
        ------
        RuntimeError
            If received HELLO

        """
        if recv == HELLO:
            logging.debug("Received {} from {}".format(HELLO, sender))
            # TODO: how to properly shut down here
            self.flag_running.value = 0
            self.recv_queue.close()
            self.send_queue.close()
            raise RuntimeError(
                "A neighbour wants to connect when everyone is connected!"
            )
        elif recv == BYE:
            logging.debug("Received {} from {}".format(BYE, sender))
            self.barrier.remove(sender)
        else:
            logging.debug("Received message from {}".format(sender))
            src, data = self.decrypt(sender, recv)
            # A separate object is delivered locally so that the changes made
            # for forwarding (fuel, visited) do not show up in it.
            new_data = data.copy()
            if data.get("rw", False):
                new_data["visited"] = new_data["visited"].copy()
                # The first visited entry is reported as the source instead of
                # the neighbour that handed the message over.
                src = new_data["visited"][0]
                logging.debug(
                    "Received message from {} was rw {}".format(sender, data["visited"])
                )
                if data["fuel"] > 0:
                    new_neighbor = self.rw_sampler()(data)
                    if new_neighbor is None:
                        logging.info(
                            "dropped rw message due to no new neigbor being available: %s",
                            str(data["visited"]),
                        )
                    else:
                        data["fuel"] -= 1
                        assert data["fuel"] + 1 == new_data["fuel"]
                        visited = data["visited"]
                        visited.append(self.uid)
                        data["visited"] = visited
                        logging.debug(
                            "Forward rw {} to {}".format(data["visited"], new_neighbor)
                        )
                        self.send_queue.put((new_neighbor, data, True))
                else:
                    # the message has no more fuel so it is dropped
                    logging.info("dropped rw message with fuel %i ", data["fuel"])
            self.recv_queue.put((src, new_data))
            logging.debug(
                "Received message from {} after putting into the queue".format(sender)
            )

    def send(self, uid, data, encrypt = True):
        """
        Send a message to a process.

        Parameters
        ----------
        uid : int
            Neighbor's unique ID. If it is none then it means we sample the neighbour!
        data : dict
            Message as a Python dictionary, or the ready-made payload when
            encrypt is False
        encrypt : bool
            Whether data still has to be encoded with self.encrypt

        """
        assert self.initialized
        # Only messages that still have to be encoded can be random walks.
        rw = data.get("rw", False) if encrypt else False
        logging.debug("send: rw? {}".format(rw))

        if uid is None and rw:  # a rw message
            # The destination is settled before encoding, so that a dropped
            # message is neither encoded nor counted.
            if self.flag_running.value != 1:  # Do not send rw if we are shutting down
                logging.debug("send: rw dropped due to being in shutdown")
                return
            uid = self.rw_sampler()(data)
            logging.debug("send: rw to {}".format(uid))
            if uid is None:
                logging.info("send: rw dropped, sampler returned no neighbour")
                return

        to_send = self.encrypt(data) if encrypt else data
        data_size = len(to_send)
        self.total_bytes.value += data_size

        peer_key = str(uid).encode()
        self.peer_sockets[peer_key].send(to_send)
        logging.debug("{} sent the message to {}.".format(self.uid, uid))
        logging.info("Sent this round: {}".format(data_size))

    def disconnect_neighbors(self):
        """
        Disconnects all neighbors.

        On the first call BYE is sent to every peer and the method waits
        until every peer in the barrier has answered with BYE. Other messages
        that arrive in the meantime are only logged. All sockets are closed
        afterwards.

        """
        assert self.initialized == True
        already_announced = self.sent_disconnections
        if not already_announced:
            logging.info("Disconnecting neighbors")
            for peer in self.peer_sockets.values():
                peer.send(BYE)
            self.sent_disconnections = True

            while self.barrier:
                sender, payload = self.router.recv_multipart()
                if payload == BYE:
                    logging.debug("Received {} from {}".format(BYE, sender))
                    self.barrier.remove(sender)
                    continue

                # this can happen now due to async
                logging.info("Received unexpected message from {}".format(sender))
                _, message = self.decrypt(sender, payload)
                if message.get("rw", False):
                    logging.info("Message was rw {}".format(message["visited"]))
                else:
                    logging.info("Message was normal")

        for peer in self.peer_sockets.values():
            peer.close()
        self.router.close()

    def rw_sampler_equi(self, message):
        """
        Picks the next hop uniformly at random among all neighbours.

        Parameters
        ----------
        message : dict or None
            The random-walk message; not looked at by this sampler

        Returns
        -------
        int
            uid of the chosen neighbour

        """
        candidates = self.neighbors
        drawn = torch.randint(
            0, len(candidates), size=(1,), generator=self.random_generator
        )
        index = drawn.item()
        logging.debug(
            "rw_sampler_equi selected index {} of {} {}".format(
                index, candidates, type(candidates)
            )
        )
        return list(candidates)[index]

    def rw_sampler_equi_check_history(self, message):
        """
        Picks the next hop uniformly at random among the neighbours that the
        random walk has not visited yet.

        Parameters
        ----------
        message : dict or None
            The random-walk message with its "visited" list. None means the
            walk starts here, so every neighbour is eligible.

        Returns
        -------
        int or None
            uid of the chosen neighbour, or None if all neighbours were
            already visited

        """
        if message is None:  # RW starts from here
            index = torch.randint(
                0, len(self.neighbors), size=(1,), generator=self.random_generator
            ).item()
            logging.debug(
                "rw_sampler_equi_check_history selected index {} of {} {}".format(
                    index, self.neighbors, type(self.neighbors)
                )
            )
            return list(self.neighbors)[index]

        visited = set(message["visited"])
        neighbors = self.neighbors
        # Also accept a list or tuple; a set is used unchanged.
        if not isinstance(neighbors, (set, frozenset)):
            neighbors = set(neighbors)
        possible_neigbors = neighbors.difference(visited)
        if not possible_neigbors:
            return None

        index = torch.randint(
            0,
            len(possible_neigbors),
            size=(1,),
            generator=self.random_generator,
        ).item()
        return list(possible_neigbors)[index]

    def rw_sampler_mh(self, message):
        """
        Placeholder for the Metro hastings version of the sampler.

        Intended to sample the neighbour based on the MH weights and to allow
        self loops (will just decrease fuel and reroll!). Not implemented yet.

        """
        return None