# TCPRandomWalkRouting.py
import importlib
import faulthandler
import json
import logging
import lzma
import multiprocessing as mp
import pickle
import time
import weakref
from collections import deque
from ctypes import c_int, c_long
from multiprocessing.sharedctypes import Value
from queue import Empty
import traceback

import torch
import zmq

from decentralizepy.communication.TCPRandomWalk import TCPRandomWalkBase

faulthandler.enable(all_threads=True)  # dump Python tracebacks of all threads on fatal signals

import torch.multiprocessing as mp

from decentralizepy.communication.Communication import Communication

# Control-message tags exchanged between peers as (tag, payload) tuples:
# HELLO requests a connection, BYE requests a disconnection, and NOBYE is
# the reply that clears a pending outgoing BYE on the receiving side.
HELLO, BYE, NOBYE = b"HELLO", b"BYE", b"NOBYE"


class TCPRandomWalkRouting(TCPRandomWalkBase):
    """
    Front-end of the random-walk routing communication layer.

    The networking itself (zmq sockets, HELLO/BYE handshakes, random-walk
    forwarding) runs in a separate process, TCPRandomWalkRoutingInternal,
    spawned by ``connect_neighbors``. This object only exchanges data with
    that process through two multiprocessing queues, a pipe and a few shared
    values.

    """

    def __init__(
        self,
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        compress=False,
        offset=2000,
        compression_package=None,
        compression_class=None,
        sampler="equi",
        neighbor_bound=(3, 5),
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process
        mapping : decentralizepy.mappings.Mapping
            uid, rank, machine_id invertible mapping
        total_procs : int
            Total number of processes
        addresses_filepath : str
            Addresses file, forwarded to the base class and the internal process
        compress : bool
            Forwarded to the base class constructor
        offset : int
            Forwarded to the base class and the internal process
        compression_package, compression_class : str or None
            Forwarded to the base class and the internal process
        sampler : str
            Name of the random-walk sampler used by the internal process
        neighbor_bound : tuple(int, int)
            (lower, upper) bound on the number of neighbors

        """
        super().__init__(
            rank,
            machine_id,
            mapping,
            total_procs,
            addresses_filepath,
            compress,
            offset,
            compression_package,
            compression_class,
        )
        self.sampler = sampler
        # (uid, data) items for the internal process to send, and the
        # per-round received data it hands back.
        self.send_queue = mp.Queue(1000)
        self.recv_queue = mp.Queue(1000)
        (self.our_pipe, self.their_pipe) = mp.Pipe()
        self.neighbors = None
        self.neighbor_bound = neighbor_bound
        # Traffic counters shared with the internal process; they share one lock.
        self.lock = mp.Lock()
        self.total_data = mp.Value(c_long, 0, lock=self.lock)
        self.total_meta = mp.Value(c_long, 0, lock=self.lock)
        self.total_bytes = mp.Value(c_long, 0, lock=self.lock)
        # 1 while the internal process should keep running. Its communication
        # loop exits once this is 0 and an iteration had nothing to do.
        self.flag_running = mp.Value(c_int, 0, lock=False)

    def connect_neighbors(self, neighbors):
        """
        Spawns TCPRandomWalkRoutingInternal in a forked process. That process
        connects to the neighbours and runs the communication loop.
        This function should only be called once.

        Parameters
        ----------
        neighbors : list(int)
            List of neighbors

        """
        self.flag_running.value = 1
        self.ctx = mp.start_processes(
            # start_processes calls fn(process_index, *args); drop the index.
            lambda *args: TCPRandomWalkRoutingInternal(*(args[1:])),
            args=(
                self.rank,
                self.machine_id,
                self.mapping,
                self.total_procs,
                self.addresses_filepath,
                self.send_queue,
                self.recv_queue,
                self.their_pipe,
                self.flag_running,
                neighbors,
                self.total_data,
                self.total_meta,
                self.total_bytes,
                self.offset,
                self.compression_package,
                self.compression_class,
                self.sampler,
                self.neighbor_bound,
            ),
            start_method="fork",
            join=False,
        )

    def receive(self, block=True, timeout=None):
        """
        Returns the data of one delivered round.

        Parameters
        ----------
        block : bool
            Whether to wait for data
        timeout : float or None
            Maximum number of seconds to wait when ``block`` is True.
            None waits indefinitely.

        Returns
        ----------
        dict or None
            sender : data already decrypted. None if nothing arrived in time
            (or immediately, when ``block`` is False).

        """
        logging.debug("Receive in TCPRandomWalk")
        try:
            # The timeout argument used to be ignored (always None).
            return self.recv_queue.get(block=block, timeout=timeout)  # already decrypted
        except Empty:
            return None

    def send(self, uid, data):
        """
        Hands a message to the internal process for sending.

        Parameters
        ----------
        uid : int, tuple or str
            Neighbor's unique ID (or a routing directive understood by the
            internal process)
        data : dict
            Message as a Python dictionary

        """
        if uid is not None:
            logging.debug(f"Send to {uid} in TCPRandomWalk")
        self.send_queue.put((uid, data))

    def get_current_neighbors(self):
        """
        Get the currently connected neighbors

        Returns
        ----------
        list
            the neighbors last sent over the pipe, or the cached value if
            nothing new is available

        """
        # This function is currently never used
        if self.our_pipe.poll():
            self.neighbors = self.our_pipe.recv()
        return self.neighbors

    def disconnect_neighbors(self):
        """
        Stops the internal process and releases the queues and the pipe.

        Clears ``flag_running`` so the internal communication loop can finish.
        That loop only exits once the flag is 0 and it has nothing left to
        flush. This method then waits for the spawned process and closes the
        queues and the pipe.

        """
        logging.info("disconnect_neighbors")
        self.flag_running.value = 0
        ctx = getattr(self, "ctx", None)
        if ctx is not None:
            # ProcessContext.join returns False while processes are still alive.
            while not ctx.join():
                pass
        self.send_queue.close()
        self.recv_queue.close()
        self.send_queue.join_thread()
        self.recv_queue.join_thread()
        self.our_pipe.close()
        logging.info(f"disconnect_neighbors: joined {self.uid}")


class TCPRandomWalkRoutingInternal(TCPRandomWalkBase):
    def __init__(
        self,
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        send_queue: mp.Queue,
        recv_queue: mp.Queue,
        pipe: mp.Pipe,
        flag_running,
        neighbors,
        total_data,
        total_meta,
        total_bytes,
        offset=2000,
        compression_package=None,
        compression_class=None,
        sampler="equi",
        neighbor_bound=(3, 5),
        compress=False,
    ):
        """
        Constructor. Binds the ROUTER socket and sends HELLO to the initial
        neighbors. It then runs ``comm_loop`` until shutdown, so it only
        returns once this internal process is done. Afterwards the shared
        queues are drained and closed.

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process
        mapping : decentralizepy.mappings.Mapping
            uid, rank, machine_id invertible mapping
        total_procs : int
            Total number of processes
        addresses_filepath : str
            JSON file with machine_id -> ip mapping
        send_queue : multiprocessing.Queue
            (uid, data) items to send. Also used by this process to forward
            random-walk messages.
        recv_queue : multiprocessing.Queue
            Receives one dict of received data per delivered round
        pipe : multiprocessing connection
            Stored as ``self.pipe``
        flag_running : multiprocessing.Value
            Shutdown flag. The loop exits once it is 0 and nothing is pending.
        neighbors : iterable of int
            Initial neighbors
        total_data, total_meta, total_bytes : multiprocessing.Value
            Shared traffic counters
        offset, compression_package, compression_class
            Forwarded to the base class constructor
        sampler : str
            "equi" or "equi_check_history". Any other value falls back to "equi".
        neighbor_bound : tuple(int, int)
            (lower, upper) bound on the number of neighbors
        compress : bool
            Forwarded to the base class and stored as ``self.compress``.
            Defaults to False.

        """
        super().__init__(
            rank,
            machine_id,
            mapping,
            total_procs,
            addresses_filepath,
            compress,
            offset,
            compression_package,
            compression_class,
        )
        try:
            self.sampler = sampler
            self.context = zmq.Context()
            self.router = self.context.socket(zmq.ROUTER)
            self.router.setsockopt(zmq.IDENTITY, self.identity)
            self.router.bind(self.addr(rank, machine_id))
            self.sent_disconnections = False
            self.compress = compress
            self.neighbor_bound_lower = neighbor_bound[0]
            self.neighbor_bound_upper = neighbor_bound[1]
            self.total_data = total_data
            self.total_meta = total_meta
            self.total_bytes = total_bytes

            self.send_queue = send_queue
            self.recv_queue = recv_queue
            self.pipe = pipe
            self.flag_running = flag_running
            self.init_neighbors = neighbors

            # Round / handshake bookkeeping.
            self.current_round = 0
            self.current_data = None
            self.future_neighbours = {}
            self.outgoing_request = {}
            self.received_data = {}
            self.init = False
            self.outgoing_byes = {}
            self.future_byes = {}
            self.this_round_bye = set()

            self.random_generator = torch.Generator()
            # new random seed from the random device
            self.rw_seed = self.random_generator.seed()
            logging.info("Machine %i has random seed %i for RW", self.uid, self.rw_seed)
            self.rw_messages_stat = []
            self.rw_double_count_stat = []
            # WeakMethod avoids a strong self -> bound method -> self cycle.
            if self.sampler == "equi":
                logging.info("rw_samper is rw_sampler_equi")
                self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi)
            elif self.sampler == "equi_check_history":
                logging.info("rw_samper is rw_sampler_equi_check_history")
                self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi_check_history)
            else:
                logging.info("rw_samper is rw_sampler_equi (default)")
                self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi)
            self.peer_sockets = dict()
            self.current_neighbors = set()
            self.connect_neighbors(self.init_neighbors)

            self.comm_loop()
        except BaseException:
            error_message = traceback.format_exc()
            print(error_message)
            print(f"GOT EXCEPTION {self.uid}")
            logging.info("GOT EXCEPTION")
            logging.info(error_message)

        def drain(queue, label):
            # Queue.empty() is only a hint across processes, so get_nowait()
            # may still raise Empty; stop draining in that case.
            while not queue.empty():
                print(f"{self.uid}: clear {label}")
                try:
                    queue.get_nowait()
                except Empty:
                    break

        # Use the arguments, not self.*: the exception above may have been
        # raised before the queue attributes were assigned.
        drain(recv_queue, "rcv")
        recv_queue.close()
        recv_queue.join_thread()
        print(f"{self.uid}: joined recv")
        drain(send_queue, "snd")
        print(f"{self.uid}: joined send")
        send_queue.close()
        send_queue.join_thread()
        print("end")

    def __del__(self):
        """
        Destroys the zmq context, closing its sockets with linger=0.

        ``__init__`` may have failed before ``self.context`` (or even
        ``self.uid``) was set. In that case there is nothing to destroy, and
        the destructor must not raise.

        """
        context = getattr(self, "context", None)
        if context is not None:
            context.destroy(linger=0)
        logging.info(
            f"TCPRandomWalkInternal for {getattr(self, 'uid', None)} is destroyed"
        )
        # uncomment during debugging
        # self.recv_queue.close()
        # self.send_queue.close()

    def connect_neighbors(self, neighbors):
        """
        Sends a connection request (HELLO) to every neighbor in ``neighbors``.

        This does not wait for answers and does not raise on its own. The
        HELLOs coming back are processed later by ``receive`` (which calls
        ``add_neighbors``) from the communication loop.

        Parameters
        ----------
        neighbors : iterable of int
            uids of the neighbors to connect to

        """
        logging.info("Sending connection request to neighbors")
        for uid in neighbors:
            self.connect(uid)

    def connect(self, uid, initiating=True):
        """
        Sends a HELLO to ``uid``, opening a new DEALER socket to it.

        If ``uid`` is already a current neighbour, the existing socket is
        reused. The HELLO then carries the marker "fw at neighbor" instead
        of the round, and nothing else happens.

        Parameters
        ----------
        uid : int
            Unique id of the peer
        initiating : bool
            True when we start the handshake. The request is then recorded in
            ``outgoing_request`` with the current round.

        NOTE(review): if ``uid`` already has an entry in ``peer_sockets`` but
        is not a current neighbour (``receive`` calls this for random walks
        without checking), a second DEALER with the same identity is created
        and replaces the old one, which is never closed. By default a zmq
        ROUTER rejects a second peer using an already-used identity. Confirm
        whether this case can occur.

        """
        if uid in self.current_neighbors:
            self.peer_sockets[uid].send(self.encrypt((HELLO, "fw at neighbor")))
            return

        hello = (HELLO, self.current_round)
        logging.debug("Connecting to new neighbour: {}".format(uid))
        dealer = self.context.socket(zmq.DEALER)
        dealer.setsockopt(zmq.IDENTITY, self.identity)
        peer_address = self.addr(*self.mapping.get_machine_and_rank(uid))
        dealer.connect(peer_address)
        self.peer_sockets[uid] = dealer
        dealer.send(self.encrypt(hello))
        if initiating:
            self.outgoing_request[uid] = self.current_round

    def comm_loop(self):
        """
        Main loop of the internal process.

        Each iteration does the following:

        - handles at most one incoming zmq message;
        - forwards at most one (uid, data) item from ``send_queue``, but only
          once ``self.init`` is set;
        - promotes future neighbours and checks whether the current round can
          be delivered;
        - sleeps with an adaptive back-off.

        The loop ends when ``flag_running`` is 0 and the iteration did no
        work. The neighbours are then disconnected.

        """
        # TODO: May want separate loops for send and receive
        sleep_time = 0.001
        while True:
            flushed = True
            successes = 1
            try:
                sender, recv = self.router.recv_multipart(flags=zmq.NOBLOCK)
                logging.debug("comm_loop received from %i", int(sender.decode()))
                self.receive(sender, recv)
                successes += 5
                flushed = False
            except zmq.Again:
                # Nothing to receive right now.
                sleep_time *= 1.1
            except zmq.ZMQError as err:
                # A real socket error, not an empty queue: keep the loop
                # alive (as before) but do not hide it.
                logging.error("comm_loop: zmq error while receiving: %s", err)
                sleep_time *= 1.1
            if self.init:
                try:
                    uid, data = self.send_queue.get_nowait()
                    # send may block if send queue is full
                    if uid is not None:
                        logging.debug(f"comm_loop will send to {uid}")
                    else:
                        logging.debug("comm_loop will send rw")
                    self.send(uid, data)
                    successes += 5
                    flushed = False
                except Empty:
                    sleep_time *= 1.1
            self.update_neighbors()
            self.can_deliver()
            # Shrink the sleep after work was done, grow it while idle, clamp.
            sleep_time /= successes
            sleep_time = max(0.00001, sleep_time)
            sleep_time = min(0.005, sleep_time)
            time.sleep(sleep_time)

            if self.flag_running.value == 0 and flushed:
                break
        self.disconnect_neighbors()
        print(f"DISCONNECTED {self.uid}")

    def receive(self, sender, recv):
        """
        Processes ONE message received on the ROUTER socket. Returns nothing.

        Control messages are tuples:

        - ``(HELLO, round)`` adds a (future) neighbour via ``add_neighbors``.
          A HELLO marked "fw at neighbor" from a current neighbour is only
          logged.
        - ``(NOBYE, ...)`` clears our outgoing BYE to the sender.
        - ``(BYE, round)`` is passed to ``disconnect_request_handler`` for
          current neighbours. For future neighbours it is recorded in
          ``future_byes`` under ``round``.

        Random-walk messages (dicts with ``"rw"`` set) are forwarded through
        ``send_queue`` while they have fuel. When the fuel runs out, a
        connection to ``routing_info`` is requested. Any other dict is stored
        in ``received_data`` under its ``"iteration"``.

        Parameters
        ----------
        sender : bytes
            zmq identity frame of the sending peer
        recv : bytes
            Encrypted payload

        Raises
        ------
        RuntimeError
            If any of the following occurs:

            - a HELLO without the forwarding marker comes from a current
              neighbour;
            - a BYE comes from a peer that is neither a current nor a future
              neighbour;
            - data comes twice from the same peer for one iteration.

        """
        src, data = self.decrypt(sender, recv)
        if isinstance(data, tuple) and data[0] == HELLO:
            logging.debug("Received {} from {}".format(HELLO, src))
            if src in self.current_neighbors:
                # TODO: this is wrong
                if data[1] != "fw at neighbor":
                    logging.critical(
                        "{} wants to connect when already connected!".format(src)
                    )
                    raise RuntimeError(
                        "{} wants to connect when already connected!".format(src)
                    )
                logging.info("fw arrived at a neighbour")
            else:
                logging.debug("Received {} from {}".format(data, src))
                self.add_neighbors(src, data[1])
        elif isinstance(data, tuple) and data[0] == NOBYE:
            print(f"{self.uid} received no bye {data} from {src}")
            if self.outgoing_byes.pop(src, None) is None:
                logging.warning(f"Received NOBYE from {src} without an outgoing bye")
        elif isinstance(data, tuple) and data[0] == BYE:
            logging.debug("Received {} from {}".format(BYE, src))
            if src in self.current_neighbors:
                self.disconnect_request_handler(src, data)
                return
            found = False
            print(f"future neighbours are {self.future_neighbours}")
            for future_round in list(self.future_neighbours.keys()):
                future_conns = self.future_neighbours[future_round]
                print(f"future conns {future_conns} at round {future_round}")
                if (src, False) in future_conns or (src, True) in future_conns:
                    print(
                        f"RECEIVED BYE for a future neighbor {src} for round {future_round}"
                    )
                    logging.info(
                        f"RECEIVED BYE for a future neighbor {src} for round {future_round}"
                    )
                    # Not officially connected yet, so disconnect_request_handler
                    # is not used. The peer also has to stay in peer_sockets and
                    # future_neighbours for now. Record the BYE; say_goodbye
                    # processes it once round data[1] starts.
                    self.future_byes.setdefault(data[1], []).append(
                        (src, False)  # Bye did not originate here
                    )
                    found = True
            if not found:
                message = "Received {} from {} despite it not being connected".format(
                    BYE, src
                )
                logging.critical(message)
                raise RuntimeError(message)
        else:
            logging.debug("Received message from {}".format(src))
            # Copy so that forwarding (which mutates data) leaves new_data intact.
            new_data = data.copy()
            if data.get("rw", False):
                new_data["visited"] = new_data["visited"].copy()
                # The originator of the walk is the first visited node.
                src = new_data["visited"][0]
                logging.debug(
                    "Received message from {} was rw {}".format(src, data["visited"])
                )
                if data["fuel"] > 0:
                    new_neighbor = self.rw_sampler()(data)
                    if new_neighbor is None:
                        if src != self.uid:
                            logging.info(
                                "RW message is delivered here due to no new neighbors being available: %s",
                                str(data["visited"]),
                            )
                            # TODO: check if not already a neighbor, then connect
                            # to new_data["routing_info"]
                        else:
                            logging.info(
                                "RW message not deliver since it originated from here!*"
                            )
                        return
                    data["fuel"] -= 1
                    assert data["fuel"] + 1 == new_data["fuel"]
                    data["visited"].append(self.uid)
                    logging.debug(
                        "Forward rw {} to {}".format(data["visited"], new_neighbor)
                    )
                    self.send_queue.put((new_neighbor, data))
                    # We only deliver rw messages at the destination
                    return
                # The message has no more fuel: this node is its destination.
                if src != self.uid:
                    # TODO: keep track of all rw we send!
                    # TODO: what if it is our neighbour?
                    logging.info(
                        "RW message from originator %i reached its destination with fuel %i ",
                        src,
                        data["fuel"],
                    )
                    self.connect(new_data["routing_info"])
                else:
                    logging.info(
                        "RW message not deliver since it originated from here!"
                    )
                return

            iteration = new_data["iteration"]
            round_data = self.received_data.setdefault(iteration, dict())
            if src in round_data:
                message = "Received data from {} twice for iteration {}".format(
                    src, iteration
                )
                logging.critical(message)
                raise RuntimeError(message)
            round_data[src] = new_data  # todo: maybe store in deque
            logging.debug(
                "Received message from {} for iter {} after putting into ".format(
                    src, iteration
                )
            )

    def can_deliver(self):
        """
        Delivers the current round once everything for it has arrived.

        Delivery requires all of the following:

        - this round's own data was handed over (``current_data`` is set);
        - no HELLO or BYE we initiated is still unanswered;
        - no future neighbours are pending for this round;
        - data was received from exactly the current neighbours.

        On delivery, the neighbours in ``this_round_bye`` are dropped and
        their sockets closed, and the round counter advances. The round's data
        dict is then put on ``recv_queue``, and the future neighbours of the
        new round are connected.

        """
        if self.current_data is None:  # this round's data not handed over yet
            return
        if self.outgoing_request:
            logging.debug(f"still have outgoing requests {self.outgoing_request}")
            return
        pending_future = self.future_neighbours.get(self.current_round, [])
        if pending_future:
            logging.debug(f"still have future neighbours {pending_future}")
            return
        if self.outgoing_byes:
            logging.debug(f"still have outgoing byes {len(self.outgoing_byes)}")
            return
        received_from = set(self.received_data.get(self.current_round, dict()).keys())
        if not received_from:
            logging.debug("No received data yet")
            return
        logging.debug(f"received data is {received_from} need: {self.current_neighbors}")
        # NOTE: exact equality with current_neighbors is required here, even
        # though neighbours that already disconnected might make this too strict.
        # TODO: make sure we got the model for the neighbours already removed due to bye
        if received_from != self.current_neighbors:
            return

        logging.debug(f"can deliver in round {self.current_round}")
        for neighbor in self.this_round_bye:
            logging.debug(f"remove neighbor {neighbor}")
            self.current_neighbors.remove(neighbor)
            self.peer_sockets.pop(neighbor).close()
        logging.info(
            f"At the end of {self.current_round} round the neighbors are {self.current_neighbors}"
        )
        self.this_round_bye = set()
        self.current_data = None
        delivered_round = self.current_round
        self.current_round += 1
        self.recv_queue.put(self.received_data[delivered_round])
        # Establish connections with future neighbours of the new round. This
        # must happen here, since some of them may also be in future byes.
        self.update_neighbors()
        del self.received_data[delivered_round]

    def add_neighbors(self, sender, round):
        """
        Handles a HELLO from ``sender`` tagged with the sender's ``round``.

        In round 0 the sender becomes a current neighbour. Once all initial
        neighbours have answered, ``init`` is set.

        Later rounds:

        - A HELLO for the current round (or, when the peer initiated, an
          older round) makes the sender a current neighbour, and this
          round's data is sent to it if available.
        - A HELLO from a peer that is ahead records it in
          ``future_neighbours``.

        Parameters
        ----------
        sender : int
            uid of the peer that sent the HELLO
        round : int
            Round number carried by the HELLO

        Raises
        ------
        RuntimeError
            On a repeated HELLO in round 0, or when a peer we contacted
            answers with an older round than ours.

        """
        logging.info(f"Processing a HELLO from {sender} marked with {round}")
        if self.current_round == 0:
            if sender in self.current_neighbors:
                logging.critical(f"Received a hello from {sender} with {round}")
                raise RuntimeError(f"Received a hello from {sender} with {round}")
            logging.info(f"Added {sender} to current neighbors")
            self.current_neighbors.add(sender)
            # current_neighbors is a set while the initial neighbours are
            # documented as a list; a set never equals a list, so compare sets.
            if self.current_neighbors == set(self.init_neighbors):
                logging.info("Added all initial neighbours to current neighbors")
                self.init = True
                self.outgoing_request = dict()
            return

        if sender in self.outgoing_request:
            if round == self.current_round:
                # We do not advance as long as there are outgoing requests.
                # current_data is None if sharing has not yet started for this
                # round, i.e. we are still training locally.
                if self.current_data is not None:
                    # messaging is tcp based --> always arrives after hello
                    self.send(sender, self.current_data)
                self.current_neighbors.add(sender)
                logging.info(f"Added {sender} to current neighbors")
                del self.outgoing_request[sender]
            elif round > self.current_round:  # they are ahead
                self.future_neighbours.setdefault(round, []).append(
                    (sender, True)  # True -> we initiated
                )
                logging.info(f"Added {sender} to future neighbors")
                del self.outgoing_request[sender]  # cannot advance until this is empty
            else:
                # We are ahead. This should never happen: the other node
                # should first advance and then answer with its current round.
                logging.critical(f"Received a hello from {sender} with {round}")
                raise RuntimeError(f"Received a hello from {sender} with {round}")
            return

        logging.debug(f"Connection request not initiated by us from {sender}")
        if round > self.current_round:  # they are ahead
            self.future_neighbours.setdefault(round, []).append((sender, False))
            logging.info(f"Added {sender} to future neighbors")
            # we contact them in the next local round
        elif round <= self.current_round:
            # Same round: the other node will not advance until this request
            # is answered. Older round: we are ahead and have to wait for them.
            if sender not in self.peer_sockets:
                self.connect(sender, initiating=False)
            self.current_neighbors.add(sender)
            logging.info(f"Added {sender} to current neighbors")
            # current_data is None if sharing has not yet started for this
            # round, i.e. we are still training locally.
            if self.current_data is not None:
                # messaging is tcp based --> always arrives after hello
                self.send(sender, self.current_data)

    def update_neighbors(self):
        """
        Promotes the future neighbours recorded for the current round.

        Each one becomes a current neighbour. Peers that contacted us first
        (``initiator`` False) get a socket if none exists yet. Each promoted
        peer is sent this round's data if it is already available. The
        round's entry is then removed from ``future_neighbours``.

        """
        pending = self.future_neighbours.get(self.current_round)
        if pending is None:
            return
        for peer, we_initiated in pending:
            if not we_initiated and peer not in self.peer_sockets:
                self.connect(peer, initiating=False)
            self.current_neighbors.add(peer)
            logging.info(f"Added {peer} from future to current neighbors")
            # current_data stays None while we are still training locally.
            if self.current_data is not None:
                self.send(peer, self.current_data)
        del self.future_neighbours[self.current_round]

    def send(self, uid, data):
        """
        Send a message to a peer, to all neighbours, or along a random walk.

        Parameters
        ----------
        uid : int, tuple or str
            One of the following:

            - a peer uid: send ``data`` to that peer, if a socket to it
              still exists;
            - ``("all", round)``: store ``data`` as this round's data, send
              it to every current neighbour, then run ``say_goodbye``;
            - ``"rw"`` with ``data["rw"]`` truthy: send ``data`` to a
              neighbour chosen by the rw sampler. The message is dropped
              while shutting down or when there are no neighbours.
        data : dict
            Message as a Python dictionary

        """

        def send_raw(peer, payload):
            # Sends already-encrypted bytes and updates the shared byte counter.
            data_size = len(payload)
            self.total_bytes.value += data_size
            self.peer_sockets[peer].send(payload)
            logging.debug("{} sent the message to {}.".format(self.uid, peer))
            logging.info("Sent data of size: {}".format(data_size))

        logging.debug("send: rw? {}".format(data.get("rw", False)))
        to_send = self.encrypt(data)
        if uid == "rw" and data.get("rw", False):
            # rw messages are shots into the blind and are not tracked, at
            # least as long as they are sent before testing. Otherwise they
            # could slow down the whole process, since connecting takes time.
            # If testing is removed, add a sleep for the connection
            # reshuffling rounds.
            if self.flag_running.value != 1:  # Do not send rw if we are shutting down
                logging.debug("send: rw dropped due to being in shutdown")
                return
            # This node may still be running after all other nodes already
            # said bye, leaving no neighbours to pick from.
            if len(self.current_neighbors) == 0:
                logging.info("send: rw dropped due to having no neighbors")
                return
            # The sampler gets the message dict (as in receive), not the
            # encrypted bytes.
            uid = self.rw_sampler()(data)
            logging.debug("send: rw to {}".format(uid))
        elif isinstance(uid, tuple) and uid[0] == "all":
            assert uid[1] == self.current_round
            self.current_data = data
            logging.debug(f"Sending to all neighbors {self.current_neighbors}")
            for n in self.current_neighbors:
                send_raw(n, to_send)
            self.say_goodbye()
            return
        if uid in self.peer_sockets:
            send_raw(uid, to_send)
        else:
            logging.info(
                f"{uid} was removed from the peer sockets probably because it already finished"
            )

    def say_goodbye(self, last=False):
        """
        Processes this round's BYEs and, unless ``last``, may start a new one.

        This is called at the beginning of the round, before there was any
        opportunity to receive something. Entries recorded in ``future_byes``
        for the current round move to ``this_round_bye``. For entries that
        did not originate here, our BYE is sent to the peer first.

        If the neighbours left after this round's BYEs still exceed the upper
        bound, one neighbour is picked with ``rw_sampler_equi`` and sent a
        BYE, unless it is already being removed.

        Parameters
        ----------
        last : bool
            If True, only the pending BYEs are processed.

        """
        pending = self.future_byes.get(self.current_round)
        if pending is not None:
            logging.debug(f"processing future byes and adding them this rounds byes")
            for entry in pending:
                logging.debug(f"processing {entry}")
                peer = entry[0]
                originated_here = entry[1] == True  # noqa: E712 - keep exact check
                if not originated_here:
                    self.peer_sockets[peer].send(
                        self.encrypt((BYE, self.current_round))
                    )
                self.this_round_bye.add(peer)
            del self.future_byes[self.current_round]

        if last:
            return
        remaining = len(self.current_neighbors) - len(self.this_round_bye)
        if remaining <= self.neighbor_bound_upper:
            return
        logging.info(
            f"Initiating a goodbye since have more neighbors {self.current_neighbors} than the upper bound"
        )
        # TODO: maybe track the neighbors by their age
        # TODO: do more than bye in a single round
        selected_neighbor = self.rw_sampler_equi(None)
        if selected_neighbor in self.this_round_bye:
            logging.info(
                f"initiating goodbye with {selected_neighbor} failed as we are already in the process of removing it"
            )
            return
        logging.info(f"initiating goodbye with {selected_neighbor}")
        self.peer_sockets[selected_neighbor].send(
            self.encrypt((BYE, self.current_round))
        )
        self.outgoing_byes[selected_neighbor] = self.current_round

    def disconnect_request_handler(self, sender, bye):
        """
        Handle a BYE message received from ``sender``.

        If ``sender`` is in ``self.outgoing_byes``, the BYE answers one we sent.
        Such a BYE is always accepted (see ``_resolve_outgoing_bye``).
        Otherwise the peer initiated the goodbye. We then accept it, defer it
        to a later round, or reject it with NOBYE, depending on the round
        numbers and the lower neighbour bound (see ``_answer_incoming_bye``).
        A BYE from a node that is not in ``self.peer_sockets`` is only logged.

        Parameters
        ----------
        sender : uid
            Peer that sent the BYE (a key of ``self.peer_sockets``).
        bye : tuple
            The decoded ``(BYE, round)`` message; ``bye[1]`` is the round the
            sender was in when it sent the BYE.

        Raises
        ------
        RuntimeError
            If a peer-initiated BYE is more than one round away from ours.

        """
        # TODO: handle rejecting the bye
        if sender not in self.peer_sockets:
            logging.info(
                f"Got a goodbye {bye} from a node {sender} that we are not connected to."
            )
            return
        if sender in self.outgoing_byes:
            self._resolve_outgoing_bye(sender, bye)
        else:
            self._answer_incoming_bye(sender, bye)
        # The socket is deliberately not closed here. It would linger until
        # the BYE is sent, which is a problem if we reconnect shortly.

    def _resolve_outgoing_bye(self, sender, bye):
        """Settle a BYE from ``sender`` while our own BYE to it is pending (never rejected)."""
        ours = self.outgoing_byes[sender]
        theirs = bye[1]
        if ours == theirs:
            # Same round: since we initiated, this round's data was already sent.
            # TODO: self.current_neighbors.remove(sender) in delivery
            self.this_round_bye.add(sender)
            del self.outgoing_byes[sender]
            logging.info(f"added {sender} to this round's byes")
            # TODO: in delivery remove this_round_bye from self.peer_sockets
            # TODO: in delivery cannot advance unless self.outgoing_byes is empty
        elif ours > theirs:
            # The other node is behind. It should defer its BYE until it
            # reaches our round. If both nodes send BYEs at the same time,
            # this can still occur; in that case the higher bye wins.
            logging.info(
                f"Received a bye from {sender} with {bye[1]} (one behind us?)"
            )
            self.this_round_bye.add(sender)
            del self.outgoing_byes[sender]
            logging.info(f"added {sender} to this round's byes")
            if ours != theirs + 1:
                # Should not happen as they are connected
                logging.warning(
                    f"A bye was received at {ours} from {sender} that is behind more than one {bye[1]}"
                )
        elif ours < theirs:
            # We are behind. We disconnect with certainty next round, and we
            # still send our data. The final disconnect is handled in
            # delivery. True marks that we initiated this goodbye.
            # TODO: make sure all entries at bye[1] get deleted in delivery
            self.future_byes.setdefault(theirs, []).append((sender, True))
            logging.info(f"added {sender} to round {bye[1]} byes")
            # Must leave outgoing_byes, otherwise we cannot advance.
            del self.outgoing_byes[sender]
            if ours != theirs - 1:
                # Should not happen as they are connected
                logging.warning(
                    f"A bye was received at {ours} from {sender} that is ahead more than one {bye[1]}"
                )

    def _answer_incoming_bye(self, sender, bye):
        """Accept, defer, or reject (NOBYE) a BYE that ``sender`` initiated."""
        logging.debug(f"sender {sender} not in outgoing byes")
        # future_byes maps a round to a list of (uid, initiated_by_us) tuples.
        # The uids pending removal therefore come from its values; its keys
        # are round numbers.
        pending = self.this_round_bye.union(
            entry[0] for entries in self.future_byes.values() for entry in entries
        )
        if len(self.current_neighbors) - len(pending) <= self.neighbor_bound_lower:
            logging.info(
                f"{self.uid} reached lower bound, current byes {self.this_round_bye} future byes {self.future_byes} current neighbours {self.current_neighbors} current round {self.current_round} and union {pending}"
            )
            self.peer_sockets[sender].send(
                self.encrypt((NOBYE, self.current_round))
            )
            return
        if bye[1] == self.current_round:
            # The other node will not advance until it receives our bye, so we
            # can safely disconnect. Our data is sent as soon as current_round
            # is increased.
            # TODO: self.current_neighbors.remove(sender) in delivery
            self.peer_sockets[sender].send(
                self.encrypt((BYE, self.current_round))
            )
            self.this_round_bye.add(sender)
            logging.info(f"added {sender} to this round's byes")
            # TODO: in delivery remove this_round_bye from self.peer_sockets
        elif bye[1] == self.current_round + 1:
            # The other node is ahead: we reply with our bye in the next round
            # (say_goodbye handles entries flagged False).
            self.future_byes.setdefault(bye[1], []).append((sender, False))
            logging.info(f"added {sender} to next round's byes")
        elif bye[1] == self.current_round - 1:
            # We are ahead and have to wait for them. They have not sent this
            # round's data yet, so they stay in current_neighbors until
            # delivery.
            # TODO: not sure about this one
            self.peer_sockets[sender].send(
                self.encrypt((BYE, self.current_round))
            )
            self.this_round_bye.add(sender)
            logging.info(f"added {sender} to this round's byes")
        else:
            logging.critical(f"disconnect request for {sender} with round {bye[1]}")
            raise RuntimeError(
                f"disconnect request for {sender} with round {bye[1]}"
            )

    def disconnect_neighbors(self):
        """
        Disconnects all neighbors.

        On the first call, this sends a BYE tagged with the current round to
        every peer that is in neither ``this_round_bye`` nor
        ``outgoing_byes``. It then flushes the BYEs scheduled for the current
        round via ``say_goodbye(last=True)``. Finally it keeps receiving
        messages until ``current_neighbors`` equals ``this_round_bye``.

        On every call, all peer sockets and the router socket are closed.
        This also happens if the receive loop raises.

        """
        try:
            if not self.sent_disconnections:
                logging.info("Disconnecting neighbors")
                for uid, sock in self.peer_sockets.items():
                    if uid in self.this_round_bye or uid in self.outgoing_byes:
                        # already sent a bye
                        continue
                    sock.send(self.encrypt((BYE, self.current_round)))
                    self.outgoing_byes[uid] = self.current_round
                self.sent_disconnections = True
                # current_round is increased once delivery happens, but
                # say_goodbye is not called again after that. Flush this
                # round's scheduled byes here instead.
                self.say_goodbye(last=True)
                while self.current_neighbors != self.this_round_bye:
                    # future_byes values hold (uid, initiated_by_us) tuples;
                    # its keys are rounds, not uids.
                    pending = self.this_round_bye.union(
                        entry[0]
                        for entries in self.future_byes.values()
                        for entry in entries
                    )
                    logging.debug(
                        f"current byes {self.this_round_bye} future byes {self.future_byes} current neighbours {self.current_neighbors} current round {self.current_round} and union {pending}"
                    )
                    sender, recv = self.router.recv_multipart()
                    sender, recv = self.decrypt(sender, recv)
                    if isinstance(recv, tuple) and recv[0] == BYE:
                        logging.debug("Received {} from {}".format(BYE, sender))
                        logging.info(
                            f"disconnect_neighbors: {self.uid} received bye from {sender}"
                        )
                        self.disconnect_request_handler(sender, recv)
                    elif isinstance(recv, tuple) and recv[0] == NOBYE:
                        # The peer rejected our bye; send it again after a
                        # short pause until it is accepted.
                        logging.debug(
                            f"{self.uid}received nobye in disconnect_neighbors {recv} from {sender}"
                        )
                        self.peer_sockets[sender].send(
                            self.encrypt((BYE, self.current_round))
                        )
                        time.sleep(0.05)
                    elif isinstance(recv, tuple):
                        # this is a hello
                        logging.debug(f"{self.uid} Other {recv} from {sender}")
                    else:
                        # Data messages can still arrive here because
                        # communication is asynchronous.
                        logging.info(
                            "Received unexpected message from {}".format(sender)
                        )
                        if isinstance(recv, dict) and recv.get("rw", False):
                            logging.info("Message was rw {}".format(recv["visited"]))
                        else:
                            logging.info("Message was normal")
        finally:
            for sock in self.peer_sockets.values():
                sock.close()
            self.router.close()

    def rw_sampler_equi(self, message):
        """
        Pick one of the current neighbours uniformly at random.

        Parameters
        ----------
        message : object
            The random-walk message being routed. This sampler ignores it.

        Returns
        -------
        object
            An element of ``self.current_neighbors``, chosen with
            ``self.random_generator``. With no neighbours, ``torch.randint``
            fails.

        """
        neighbours = self.current_neighbors
        ordered = list(neighbours)
        position = torch.randint(
            0, len(neighbours), size=(1,), generator=self.random_generator
        ).item()
        logging.debug(
            "rw_sampler_equi selected index {} of {} {}".format(
                position, neighbours, type(neighbours)
            )
        )
        return ordered[position]

    def rw_sampler_equi_check_history(self, message):
        """
        Pick a neighbour uniformly at random, skipping nodes the walk visited.

        Parameters
        ----------
        message : object or None
            ``None`` when the random walk starts at this node. Otherwise it is
            the random-walk message; ``message["visited"]`` lists the nodes to
            avoid.

        Returns
        -------
        object or None
            A uniformly chosen element of ``self.current_neighbors`` minus
            the visited nodes. Returns ``None`` when every neighbour was
            already visited. When ``message`` is ``None``, any neighbour may
            be returned.

        """
        if message is not None:
            already_seen = set(message["visited"])
            candidates = self.current_neighbors.difference(already_seen)
            if not candidates:
                return None
            slot = torch.randint(
                0, len(candidates), size=(1,), generator=self.random_generator
            ).item()
            return list(candidates)[slot]

        # RW starts from here: there is no history to exclude.
        slot = torch.randint(
            0,
            len(self.current_neighbors),
            size=(1,),
            generator=self.random_generator,
        ).item()
        logging.debug(
            "rw_sampler_equi_check_history selected index {} of {} {}".format(
                slot, self.current_neighbors, type(self.current_neighbors)
            )
        )
        return list(self.current_neighbors)[slot]

    def rw_sampler_mh(self, message):
        """
        Planned Metropolis-Hastings variant of the random-walk sampler.

        NOTE(review): the visible body is only a placeholder (``pass``), so
        as written it does not select a neighbour. Confirm before using it as
        a sampler.

        """
        # Metropolis-Hastings version of the sampler
        # Intended to sample the neighbour based on the MH weights, allowing self loops (will just decrease fuel and reroll!)
        pass