Newer
Older
Jeffrey Wigger
committed
import faulthandler
import json
import logging
import lzma
import multiprocessing as mp
import pickle
import time
import weakref
from collections import deque
Jeffrey Wigger
committed
from multiprocessing.sharedctypes import Value
from queue import Empty
Jeffrey Wigger
committed
import torch
import zmq
from decentralizepy.communication.TCPRandomWalk import TCPRandomWalkBase
faulthandler.enable()
import torch.multiprocessing as mp
from decentralizepy.communication.Communication import Communication
HELLO = b"HELLO"
BYE = b"BYE"
NOBYE = b"NOBYE"
class TCPRandomWalkRouting(TCPRandomWalkBase):
"""
Wrapper for TCPRandomWalkInternal, mostly a copy of Communication
Connect
"""
Jeffrey Wigger
committed
def __init__(
    self,
    rank,
    machine_id,
    mapping,
    total_procs,
    addresses_filepath,
    compress=False,
    offset=2000,
    compression_package=None,
    compression_class=None,
    sampler="equi",
    neighbor_bound=(3, 5),
):
    """
    Constructor. Sets up the shared state used to talk to the communication process.

    Parameters
    ----------
    rank : int
        Local rank of the process
    machine_id : int
        Machine id of the process
    mapping : decentralizepy.mappings.Mapping
        uid, rank, machine_id invertible mapping
    total_procs : int
        Total number of processes
    addresses_filepath : str
        JSON file with machine_id -> ip mapping
    compress, offset, compression_package, compression_class
        Forwarded unchanged to the base class constructor and, later, to the
        spawned TCPRandomWalkRoutingInternal.
    sampler : str
        Name of the random walk sampler handed to the internal process
        ("equi" or "equi_check_history"; any other value falls back to "equi" there).
    neighbor_bound : tuple(int, int)
        (lower, upper) bound on the number of neighbors, handed to the internal process.

    """
    # c_int / c_long are not imported at module level in this file, so bring
    # them into scope here; without this the Value() calls raise NameError.
    from ctypes import c_int, c_long

    super().__init__(
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        compress,
        offset,
        compression_package,
        compression_class,
    )
    self.sampler = sampler
    # Bounded queues (1000 entries each) shared with the communication process.
    self.send_queue = mp.Queue(1000)
    self.recv_queue = mp.Queue(1000)
    # our_pipe stays here; their_pipe is passed to the internal process.
    (self.our_pipe, self.their_pipe) = mp.Pipe()
    self.neighbors = None
    self.neighbor_bound = neighbor_bound
    # The three counters share one lock.
    self.lock = mp.Lock()
    self.total_data = mp.Value(c_long, 0, lock=self.lock)
    self.total_meta = mp.Value(c_long, 0, lock=self.lock)
    self.total_bytes = mp.Value(c_long, 0, lock=self.lock)
    # Unsynchronised flag: 1 while communication should run, 0 to request shutdown.
    self.flag_running = mp.Value(c_int, 0, lock=False)
Jeffrey Wigger
committed
def connect_neighbors(self, neighbors):
    """
    Starts the communication process, which connects to the neighbours.
    This function should only be called once.

    Sets flag_running to 1 and launches a single forked process that
    constructs a TCPRandomWalkRoutingInternal; the process context is kept
    in self.ctx so that it can be joined later.

    Parameters
    ----------
    neighbors : list(int)
        List of neighbors

    """
    self.flag_running.value = 1

    def spawn_internal(*proc_args):
        # The first positional argument supplied by start_processes is
        # dropped; the remainder is the constructor argument list below.
        return TCPRandomWalkRoutingInternal(*(proc_args[1:]))

    internal_args = (
        self.rank,
        self.machine_id,
        self.mapping,
        self.total_procs,
        self.addresses_filepath,
        self.send_queue,
        self.recv_queue,
        self.their_pipe,
        self.flag_running,
        neighbors,
        self.total_data,
        self.total_meta,
        self.total_bytes,
        self.compress,
        self.offset,
        self.compression_package,
        self.compression_class,
        self.sampler,
        self.neighbor_bound,
    )
    self.ctx = mp.start_processes(
        spawn_internal,
        args=internal_args,
        start_method="fork",
        join=False,
    )
def receive(self, block=True, timeout=None):
    """
    Returns a received message taken from the receive queue.

    Parameters
    ----------
    block : bool
        If True, wait for a message; if False, return immediately.
    timeout : float or None
        Maximum number of seconds to wait when blocking. None waits
        indefinitely. Previously this argument was silently ignored.

    Returns
    ----------
    dict or None
        sender : data already decrypted, or None if no message was
        available (non-blocking call or timeout expired).

    """
    try:
        # Honour the caller's timeout instead of a hard-coded None.
        return self.recv_queue.get(block=block, timeout=timeout)  # already decrypted
    except Empty:
        return None
def send(self, uid, data):
    """
    Queue a message for the communication process to send.

    Parameters
    ----------
    uid : int
        Neighbor's unique ID. A value of None is queued as well, it is
        just not logged here.
    data : dict
        Message as a Python dictionary

    """
    outgoing = (uid, data)
    if uid is not None:
        logging.debug(f"Send to {uid} in TCPRandomWalk")
    self.send_queue.put(outgoing)
def get_current_neighbors(self):
    """
    Get the currently connected neighbors

    If the pipe holds a pending update, one update is read and cached in
    self.neighbors; otherwise the cached value is returned unchanged.

    Returns
    ----------
    list
        the neighbors (None if no update has been received yet)

    """
    if self.our_pipe.poll():
        # Consume exactly one pending update per call.
        self.neighbors = self.our_pipe.recv()
    return self.neighbors
def disconnect_neighbors(self):
    """
    Disconnects all neighbors.

    Clears flag_running, closes both queues, joins their feeder threads
    and finally joins the communication process.
    """
    self.flag_running.value = 0
    shared_queues = (self.send_queue, self.recv_queue)
    for shared_queue in shared_queues:
        # original author's note on closing send_queue: "this crashes"
        shared_queue.close()
    # del self.lock
    for shared_queue in shared_queues:
        shared_queue.join_thread()
    self.ctx.join()
Jeffrey Wigger
committed
class TCPRandomWalkRoutingInternal(TCPRandomWalkBase):
def __init__(
self,
rank,
machine_id,
mapping,
total_procs,
addresses_filepath,
send_queue: mp.Queue,
recv_queue: mp.Queue,
pipe: mp.Pipe,
flag_running,
neighbors,
Jeffrey Wigger
committed
compress=False,
offset=2000,
compression_package=None,
compression_class=None,
Jeffrey Wigger
committed
sampler="equi",
neighbor_bound=(3, 5),
):
"""
Constructor
Parameters
----------
rank : int
Local rank of the process
machine_id : int
Machine id of the process
mapping : decentralizepy.mappings.Mapping
uid, rank, machine_id invertible mapping
total_procs : int
Total number of processes
addresses_filepath : str
JSON file with machine_id -> ip mapping
"""
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
super().__init__(rank, machine_id, mapping, total_procs, addresses_filepath, compress, offset,
compression_package, compression_class)
try:
self.sampler = sampler
self.context = zmq.Context()
self.router = self.context.socket(zmq.ROUTER)
self.router.setsockopt(zmq.IDENTITY, self.identity)
self.router.bind(self.addr(rank, machine_id))
self.sent_disconnections = False
self.compress = compress
self.neighbor_bound_lower = neighbor_bound[0]
self.neighbor_bound_upper = neighbor_bound[1]
self.total_data = total_data
self.total_meta = total_meta
self.total_bytes = total_bytes
self.send_queue = send_queue
self.recv_queue = recv_queue
self.pipe = pipe
self.flag_running = flag_running
self.init_neighbors = neighbors
self.current_round = 0
self.current_data = None
self.future_neighbours = {}
self.outgoing_request = {}
Jeffrey Wigger
committed
Jeffrey Wigger
committed
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
self.init = False
self.outgoing_byes = {}
self.future_byes = {}
self.this_round_bye = set()
self.random_generator = torch.Generator()
self.rw_seed = (
self.random_generator.seed()
) # new random seed from the random device
logging.info("Machine %i has random seed %i for RW", self.uid, self.rw_seed)
self.rw_messages_stat = []
self.rw_double_count_stat = []
if self.sampler == "equi":
logging.info("rw_samper is rw_sampler_equi")
self.rw_sampler = weakref.WeakMethod(
self.rw_sampler_equi
) # self.rw_sampler_equi_check_history
elif self.sampler == "equi_check_history":
logging.info("rw_samper is rw_sampler_equi_check_history")
self.rw_sampler = weakref.WeakMethod(self.rw_sampler_equi_check_history)
else:
logging.info("rw_samper is rw_sampler_equi (default)")
self.rw_sampler = weakref.WeakMethod(
self.rw_sampler_equi
) # self.rw_sampler_equi_check_history
self.peer_sockets = dict()
self.current_neighbors = set()
self.connect_neighbors(self.init_neighbors)
Jeffrey Wigger
committed
self.comm_loop()
except BaseException as e:
error_message = traceback.format_exc()
print(error_message)
logging.info("GOT EXCEPTION")
logging.info(error_message)
Jeffrey Wigger
committed
while not self.recv_queue.empty():
print(f"{self.uid}: clear rcv")
_ = self.recv_queue.get_nowait()
self.recv_queue.close()
self.recv_queue.join_thread()
print(f"{self.uid}: joined recv")
while not self.send_queue.empty():
print(f"{self.uid}: clear snd")
_ = self.send_queue.get_nowait()
print(f"{self.uid}: joined send")
self.send_queue.close()
self.send_queue.join_thread()
print("end")
Jeffrey Wigger
committed
def __del__(self):
    """
    Destroys the zmq context (linger=0) and logs the destruction.
    """
    zmq_context = self.context
    zmq_context.destroy(linger=0)
    logging.info(f"TCPRandomWalkInternal for {self.uid} is destroyed")
    # uncomment during debugging
    # self.recv_queue.close()
    # self.send_queue.close()
Jeffrey Wigger
committed
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
def connect_neighbors(self, neighbors):
    """
    Sends a connection request to every given neighbor.

    Each uid is passed to self.connect in iteration order; this method
    itself does not wait for any reply.

    Parameters
    ----------
    neighbors : list(int)
        List of neighbors

    """
    logging.info("Sending connection request to neighbors")
    for neighbor_uid in neighbors:
        self.connect(neighbor_uid)
def connect(self, uid, initiating=True):
    """
    Sends a HELLO to uid, opening a DEALER socket to it if necessary.

    Parameters
    ----------
    uid : int
        Unique ID of the node to connect to
    initiating : bool
        If True, a newly opened connection is recorded in
        outgoing_request together with the current round.

    """
    if uid in self.current_neighbors:
        # Already a neighbor: reuse the existing socket and mark the HELLO
        # accordingly; nothing else is recorded.
        self.peer_sockets[uid].send(self.encrypt((HELLO, "fw at neighbor")))
        return

    hello = (HELLO, self.current_round)
    logging.debug("Connecting to new neighbour: {}".format(uid))
    dealer = self.context.socket(zmq.DEALER)
    dealer.setsockopt(zmq.IDENTITY, self.identity)
    target = self.addr(*self.mapping.get_machine_and_rank(uid))
    dealer.connect(target)
    self.peer_sockets[uid] = dealer
    dealer.send(self.encrypt(hello))
    if initiating:
        self.outgoing_request[uid] = self.current_round
def comm_loop(self):
# Main loop of the communication process: polls the ROUTER socket without
# blocking, forwards at most one queued outgoing message per iteration,
# runs the neighbor bookkeeping and sleeps for an adaptive interval.
# The loop ends once flag_running is 0 and the iteration neither received
# nor sent anything; afterwards all neighbors are disconnected.
# NOTE(review): indentation was lost in this copy of the file. Whether
# update_neighbors()/can_deliver() sit inside `if self.init:` or at loop
# level cannot be determined from here - confirm against the repository.
# TODO: May want separate loops for send and receive
sleep_time = 0.001
while True:
# flushed stays True only if this iteration did no receive and no send.
flushed = True
# Divisor for the sleep time: 1 when idle, +5 for every successful operation.
successes = 1
try:
sender, recv = self.router.recv_multipart(flags=zmq.NOBLOCK)
logging.debug("comm_loop received from %i", int(sender.decode()))
self.receive(sender, recv)
successes += 5
flushed = False
except zmq.ZMQError:
# Nothing to receive (or another zmq error): back off by 10%.
# logging.debug("zmq error due to empty recv_multipart")
sleep_time *= 1.1
# Outgoing messages are only taken from send_queue once self.init is set.
if self.init:
try:
uid, data = self.send_queue.get_nowait()
# send may block if send queue is full
if uid is not None:
logging.debug(f"comm_loop will send to {uid}")
else:
logging.debug("comm_loop will send rw")
self.send(uid, data)
successes += 5
flushed = False
except Empty:
# Nothing queued for sending: back off by 10%.
sleep_time *= 1.1
# logging.debug("empty send queue %i", self.send_queue.empty())
self.update_neighbors()
self.can_deliver()
# Shrink the sleep after successful work and clamp it to [0.00001, 0.005] seconds.
sleep_time /= successes
sleep_time = max(0.00001, sleep_time)
sleep_time = min(0.005, sleep_time)
# logging.debug("comm loop sleeping for %f", sleep_time)
time.sleep(sleep_time)
# Shut down only when requested and this iteration was idle.
if self.flag_running.value == 0 and flushed:
break
self.disconnect_neighbors()
print(f"DISCONNECTED {self.uid}")
def receive(self, sender, recv):
"""
Returns ONE message received.
Returns
----------
dict
Received and decrypted data
Raises
------
RuntimeError
If received HELLO
"""
src, data = self.decrypt(sender, recv)
if type(data) == tuple and data[0] == HELLO:
logging.debug("Received {} from {}".format(HELLO, src))
if src in self.current_neighbors:
if data[1] != "fw at neighbor": # TODO: this is wrong
Jeffrey Wigger
committed
logging.critical(
"{} wants to connect when already connected!".format(HELLO, src)
)
raise RuntimeError(
"{} wants to connect when already connected!".format(HELLO, src)
)
else:
logging.info("fw arrived at a neighbour")
else:
logging.debug("Received {} from {}".format(data, src))
self.add_neighbors(src, data[1])
elif type(data) == tuple and data[0] == NOBYE:
print(f"{self.uid} received no bye {data} from {src}")
del self.outgoing_byes[src]
elif type(data) == tuple and data[0] == BYE:
logging.debug("Received {} from {}".format(BYE, src))
if src in self.current_neighbors:
self.disconnect_request_handler(src, data)
else:
found = False
print(f"future neighbours are {self.future_neighbours}")
keys = list(self.future_neighbours.keys())
for round in keys:
future_conns = self.future_neighbours[round]
Jeffrey Wigger
committed
print(f"future conns {future_conns} at round {round}")
if (src, False) in future_conns or (src, True) in future_conns:
print(
f"RECEIVED BYE for a future neighbor {src} for round {round}"
)
logging.info(
f"RECEIVED BYE for a future neighbor {src} for round {round}"
)
#if (src, False) in future_conns:
# future_conns.remove((src, False))
#else:
# future_conns.remove((src, True))
# Not calling disconnect_request_handler as we are not yet officially coneected
#self.peer_sockets[sender].send(
# self.encrypt((BYE, self.current_round))
#)
# assert data[1] == round
# above does not work as we also need to remove it from peer_sockets and
# the other side may have already received
self.future_byes.setdefault(data[1], []).append(
(src, False) # Bye did not originate here
Jeffrey Wigger
committed
)
found = True
#if len(future_conns) == 0:
# logging.info(
# f"There are now no new future neighbors in round {round}"
# )
# del self.future_neighbours[round]
Jeffrey Wigger
committed
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
if not found:
logging.critical(
"Received {} from {} despite it not being connected".format(
BYE, src
)
)
raise RuntimeError(
"Received {} from {} despite it not being connected".format(
BYE, src
)
)
else:
logging.debug("Received message from {}".format(src))
# ned new data object such that forwarding does not change the internal fields
# We add our selves to visited so, it could trigger "RW message was already once received"
new_data = data.copy()
if data.get("rw", False):
new_data["visited"] = new_data["visited"].copy()
src = new_data["visited"][
0
] # the original src is the neighbour that sent it to us.
logging.debug(
"Received message from {} was rw {}".format(src, data["visited"])
)
if data["fuel"] > 0:
new_neighbor = self.rw_sampler()(data)
if new_neighbor == None:
if src != self.uid:
logging.info(
"RW message is delivered here due to no new neighbors being available: %s",
str(data["visited"]),
)
# TODO: check if not already a neighbor
#self.connect(new_data["routing_info"])
Jeffrey Wigger
committed
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
return
else:
logging.info(
"RW message not deliver since it originated from here!*"
)
return
else:
data["fuel"] -= 1
assert data["fuel"] + 1 == new_data["fuel"]
visited = data["visited"]
visited.append(self.uid)
data["visited"] = visited
logging.debug(
"Forward rw {} to {}".format(data["visited"], new_neighbor)
)
self.send_queue.put((new_neighbor, data))
# We only deliver rw messages at the destination
return
else:
# the message has no more fuel so it is dropped
if src != self.uid:
# TODO: keep track of all rw we send!
# TODO: what if it is our neighbour?
logging.info(
"RW message from originator %i reached its destination with fuel %i ",
src,
data["fuel"],
)
self.connect(new_data["routing_info"])
return
else:
logging.info(
"RW message not deliver since it originated from here!"
)
return
else:
# self.recv_queue.put((src, new_data))
tmp = self.received_data.setdefault(new_data["iteration"], dict())
if src in tmp:
logging.critical(
"Received data from {} twice for iteration {}".format(
src, new_data["iteration"]
)
)
raise RuntimeError(
"Received data from {} twice for iteration {}".format(
src, new_data["iteration"]
)
)
else:
tmp[src] = new_data # todo: maybe store in deque
logging.debug(
"Received message from {} for iter {} after putting into ".format(
src, new_data["iteration"]
)
)
def can_deliver(self):
    """
    Delivers the current round's received data if the round is complete.

    Nothing happens unless all of the following hold: this round's own data
    has been set, there are no outgoing requests, no future neighbours
    registered for the current round, no outgoing byes, and data has been
    received from exactly the current neighbors. On delivery the neighbors
    in this_round_bye are dropped and their sockets closed, the round counter
    is advanced, the round's data is put on recv_queue, update_neighbors is
    called and the round's data is removed from received_data.
    """
    # check if only current neighbours are in it. and check that outgoing_request is empty and future_neighbours too
    if self.current_data is None:  # have not yet received this rounds data
        return
    if self.outgoing_request:
        logging.debug(f"still have outgoing requests {self.outgoing_request}")
        return
    if self.future_neighbours.get(self.current_round, []):
        logging.debug(
            f"still have future neighbours {self.future_neighbours.get(self.current_round, [])}"
        )
        return
    if self.outgoing_byes:
        logging.debug(f"still have outgoing byes {len(self.outgoing_byes)}")
        return

    round_data = self.received_data.get(self.current_round, dict())
    if not round_data:
        logging.debug(f"No received data yet")
        return

    received_from = set(round_data.keys())
    logging.debug(f"received data is {received_from} need: {self.current_neighbors}")
    # Cannot test for equality here with self.current_neighbors since others may have already disconnected.
    # TODO: make sure we got the model for the neighbours already removed due to by
    if received_from != self.current_neighbors:
        return

    logging.debug(f"can deliver in round {self.current_round}")
    for neighbor in self.this_round_bye:
        logging.debug(f"remove neighbor {neighbor}")
        self.current_neighbors.remove(neighbor)
        self.peer_sockets[neighbor].close()
        del self.peer_sockets[neighbor]
    logging.info(
        f"At the end of {self.current_round} round the neighbors are {self.current_neighbors}"
    )
    self.this_round_bye = set()
    self.current_data = None
    finished_round = self.current_round
    self.current_round = finished_round + 1
    self.recv_queue.put(self.received_data[finished_round])
    # establishes a connection with future neighbors.
    # must be called here, as some of these future neighbors may also be in future byes
    self.update_neighbors()
    del self.received_data[finished_round]
def add_neighbors(self, sender, round):
logging.info(f"Processing a HELLO from {sender} marked with {round}")
if self.current_round == 0:
if sender not in self.current_neighbors:
logging.info(f"Added {sender} to current neighbors")
self.current_neighbors.add(sender)
else:
logging.critical(f"Received a hello from {sender} with {round}")
raise RuntimeError(f"Received a hello from {sender} with {round}")
if self.current_neighbors == self.init_neighbors:
logging.info(f"Added all initial neighbours to current neighbors")
self.init = True
self.outgoing_request = dict()
return
if sender in self.outgoing_request:
if (
round == self.current_round
): # we do not advance as long as there outgoing requests
# current data is none if sharing has not yet started for this round, i.e. we are still training locally
if self.current_data != None:
self.send(
sender, self.current_data
) # messaging is tcp based --> always arrives after hello
self.current_neighbors.add(sender)
logging.info(f"Added {sender} to current neighbors")
del self.outgoing_request[sender]
Jeffrey Wigger
committed
self.future_neighbours.setdefault(round, []).append(
(sender, True)
) # True -> we initiated
logging.info(f"Added {sender} to future neighbors")
del self.outgoing_request[sender] # cannot advance until this is empty
elif round == self.current_round - 1: # we are ahead
Jeffrey Wigger
committed
# should never arrive, as other round should advance and then send to us with current round
logging.critical(f"Received a hello from {sender} with {round}")
raise RuntimeError(f"Received a hello from {sender} with {round}")
Jeffrey Wigger
committed
else:
logging.critical(f"Received a hello from {sender} with {round}")
raise RuntimeError(f"Received a hello from {sender} with {round}")
else:
logging.debug(f"Connection request not initiated by us from {sender}")
Jeffrey Wigger
committed
if (
round == self.current_round
): # other node will not advance as long as this request is not answered
if sender not in self.peer_sockets:
self.connect(sender, initiating=False)
self.current_neighbors.add(sender)
logging.info(f"Added {sender} to current neighbors")
# current data is none if sharing has not yet started for this round, i.e. we are still training locally
if self.current_data != None:
self.send(
sender, self.current_data
) # messaging is tcp based so this should always arrive after helo
elif round > self.current_round: # they are ahead
Jeffrey Wigger
committed
self.future_neighbours.setdefault(round, []).append((sender, False))
logging.info(f"Added {sender} to future neighbors")
# we conntact them in the next local round
Jeffrey Wigger
committed
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
# have to wait for them
if sender not in self.peer_sockets:
self.connect(sender, initiating=False)
self.current_neighbors.add(sender)
logging.info(f"Added {sender} to current neighbors")
if self.current_data != None:
self.send(
sender, self.current_data
) # messaging is tcp based --> always arrives after hello
def update_neighbors(self):
    """
    Promotes the future neighbours registered for the current round.

    Every (sender, initiator) entry stored under the current round is added
    to current_neighbors. For entries that were not initiated here a
    connection is opened first if no socket to the sender exists yet. The
    entry list for the round is removed afterwards.
    """
    # if not(len(self.future_neighbours) == 0 or (len(self.future_neighbours) == 1 and list(self.future_neighbours.keys())[0] == self.current_round )):
    #    print(f"current round {self.current_round}, and {list(self.future_neighbours.keys())}")
    if self.current_round not in self.future_neighbours:
        return

    for entry in self.future_neighbours[self.current_round]:
        sender, initiator = entry
        if not initiator:
            if sender not in self.peer_sockets:
                self.connect(sender, initiating=False)
            self.current_neighbors.add(sender)
            logging.info(f"Added {sender} from future to current neighbors")
            # current data is none if sharing has not yet started for this round, i.e. we are still training locally
            # NOTE(review): nesting of this send was reconstructed from a copy
            # without indentation - confirm it belongs to this branch only.
            if self.current_data != None:
                self.send(sender, self.current_data)
        else:
            # a connection we opened,
            self.current_neighbors.add(sender)
            logging.info(f"Added {sender} from future to current neighbors")
    del self.future_neighbours[self.current_round]
def send(self, uid, data):
"""
Send a message to a process.
Parameters
----------
uid : any
Additional information about the destination of the message
data : dict
Message as a Python dictionary
"""
def send(uid, to_send):
data_size = len(to_send)
Jeffrey Wigger
committed
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
self.peer_sockets[uid].send(to_send)
logging.debug("{} sent the message to {}.".format(self.uid, uid))
logging.info("Sent data of size: {}".format(data_size))
logging.debug("send: rw? {}".format(data.get("rw", False)))
to_send = self.encrypt(data)
if uid == "rw" and data.get("rw", False):
# a rw message, they are shots into the blind, we need not track them,
# at least no as long as we send them before the testing.
# Else they could slow down the entire process as it takes some time to connect
# So if we remove testing we should add a sleep for the designated connection reshuffling rounds
if self.flag_running.value == 1: # Do not send rw if we are shutting down
# Problem: this node code still be running, but all other nodes have already send a bye
# in this case this crashes because len(self.current_neighbors) == 0
if len(self.current_neighbors) != 0:
uid = self.rw_sampler()(to_send)
logging.debug("send: rw to {}".format(uid))
else:
logging.info("send: rw dropped due to having no neighbors")
else:
logging.debug("send: rw dropped due to being in shutdown")
return
elif type(uid) == tuple and uid[0] == "all":
assert uid[1] == self.current_round
self.current_data = data
logging.debug(f"Sending to all neighbors {self.current_neighbors}")
for n in self.current_neighbors:
send(n, to_send)
self.say_goodbye() #
return
if uid in self.peer_sockets:
send(uid, to_send)
else:
logging.info(
f"{uid} was removed from the peer sockets probably because it already finished"
)
def say_goodbye(self, last=False):
    """
    Processes the byes scheduled for the current round and possibly starts a new one.

    Every entry stored in future_byes for the current round is moved to
    this_round_bye; for entries that did not originate here a BYE marked
    with the current round is sent first. Unless `last` is set, a goodbye
    is then initiated with one sampled neighbor if the number of neighbors
    that are not already leaving exceeds the upper neighbor bound.

    Parameters
    ----------
    last : bool
        If True, no new goodbye is initiated.

    """
    # is called at the beginning of the round before there was any opportunity to receive something
    if self.current_round in self.future_byes:
        logging.debug(f"processing future byes and adding them this rounds byes")
        for pending in self.future_byes[self.current_round]:
            logging.debug(f"processing {pending}")
            peer = pending[0]
            if not (pending[1] == True):
                # The bye did not originate with us, so answer it now.
                self.peer_sockets[peer].send(self.encrypt((BYE, self.current_round)))
            self.this_round_bye.add(peer)
        del self.future_byes[self.current_round]

    if last:
        return

    remaining = len(self.current_neighbors) - len(self.this_round_bye)
    if remaining <= self.neighbor_bound_upper:
        return

    logging.info(
        f"Initiating a goodbye since have more neighbors {self.current_neighbors} than the upper bound"
    )
    # TODO: maybe track the neighbors by their age
    # TODO: do more than bye in a single round
    selected_neighbor = self.rw_sampler_equi(None)
    if selected_neighbor in self.this_round_bye:
        logging.info(
            f"initiating goodbye with {selected_neighbor} failed as we are already in the process of removing it"
        )
        return

    logging.info(f"initiating goodbye with {selected_neighbor}")
    self.peer_sockets[selected_neighbor].send(self.encrypt((BYE, self.current_round)))
    self.outgoing_byes[selected_neighbor] = self.current_round
def disconnect_request_handler(self, sender, bye):
# TODO: handle rejecting the bye
# self.current_neighbors.remove(sender)
if sender in self.peer_sockets: # in deliver remove from this list
if sender in self.outgoing_byes:
# These are never rejected
if self.outgoing_byes[sender] == bye[1]:
# since we initiated we already sent the data for this round
self.this_round_bye.add(sender)
# TODO: self.current_neighbors.remove(sender) in delivery
del self.outgoing_byes[sender]
logging.info(f"added {sender} to this round's byes")
# TODO: in delivery remove this_round_bye from self.peer_sockets
# TODO: in delivery cannot advance unless self.outgoing_byes is empty
elif self.outgoing_byes[sender] > bye[1]: # other node is behind
Jeffrey Wigger
committed
# This should never happen as the other node should defer sending the bye until it is in our round
# However if both nodes send byes at the same time then this can occur.
# In this case the higher bye wins
logging.info(
f"Received a bye from {sender} with {bye[1]} (one behind us?)"
Jeffrey Wigger
committed
)
del self.outgoing_byes[sender]
logging.info(f"added {sender} to this round's byes")
# raise RuntimeError(
# f"Received a hello from {sender} with {bye[1]} (one behind us)"
# )
print(f"A bye was received at {r_was} from {sender} that is behind more than one {bye[1]}")
logging.info(f"A bye was received at {r_was} from {sender} that is behind more than one {bye[1]}")
Jeffrey Wigger
committed
# We know with certainty that the next round we disconnect (will still send the data)
# final disconnect is handled in delivery
# True again meaning that we initiated this!
Jeffrey Wigger
committed
self.future_byes.setdefault(bye[1], []).append(
(sender, True)
) # TODO: makes sure all entries at bye[1] get deleted in delivery
del self.outgoing_byes[sender]
# need to remove it else cannot advance
print(f"A bye was received at {r_was} from {sender} that is ahead more than one {bye[1]}")
logging.info(f"A bye was received at {r_was} from {sender} that is ahead more than one {bye[1]}")
Jeffrey Wigger
committed
else: # this goodbye was not initiated by us
logging.debug(f"sender {sender} not in outgoing byes")
Jeffrey Wigger
committed
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
all_to_be_removed = self.this_round_bye.union(
set(self.future_byes.keys())
)
if (
len(self.current_neighbors) - len(all_to_be_removed)
<= self.neighbor_bound_lower
):
print(
f"{self.uid} reached lower bound, current byes {self.this_round_bye} future byes {self.future_byes} current neighbours {self.current_neighbors} current round {self.current_round} and union {self.this_round_bye.union(set(self.future_byes.keys()))}"
)
self.peer_sockets[sender].send(
self.encrypt((NOBYE, self.current_round))
)
else:
if bye[1] == self.current_round:
self.peer_sockets[sender].send(
self.encrypt((BYE, self.current_round))
)
# the other node will not advance until it receives our bye so we can safely disconnect
# TODO: self.current_neighbors.remove(sender) in delivery
# TODO: may not yet have sent the data, however we send data as soon as self.current_round is increased
self.this_round_bye.add(sender)
logging.info(f"added {sender} to this round's byes")
# TODO: in delivery remove this_round_bye from self.peer_sockets
elif bye[1] == self.current_round + 1: # the other node is ahead
Jeffrey Wigger
committed
# We need to advance and send the bye in the next round:
self.future_byes.setdefault(bye[1], []).append((sender, False))
logging.info(f"added {sender} to next round's byes")
elif bye[1] == self.current_round - 1: # we are ahead
Jeffrey Wigger
committed
# We have to wait for them
# TODO: not sure about this one
self.peer_sockets[sender].send(
self.encrypt((BYE, self.current_round))
)
# -> they have not yet send the data so we should keep it current_neighbors
# TODO: self.current_neighbors.remove(sender) in delivery
self.this_round_bye.add(sender)
logging.info(f"added {sender} to this round's byes")
else:
logging.critical(f"disconnect request for {sender} with round {bye[1]}")
raise RuntimeError(
f"disconnect request for {sender} with round {bye[1]}"
)
Jeffrey Wigger
committed
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
# self.peer_sockets[sender].close() # will linger until BYE is sent, if we reconnect shortly ...... X
# del self.peer_sockets[sender]
else:
logging.info(
f"Got a goodbye {bye} from a node {sender} that we are not connected to."
)
def disconnect_neighbors(self):
"""
Disconnects all neighbors.
"""
if not self.sent_disconnections:
logging.info("Disconnecting neighbors")
for uid, sock in self.peer_sockets.items():
if uid in self.this_round_bye or uid in self.outgoing_byes:
# already sent a bye
continue
sock.send(self.encrypt((BYE, self.current_round)))
self.outgoing_byes[uid] = self.current_round
self.sent_disconnections = True
# The self.current_round gets increased once delivery happens, but we never call self.say_goodbye
self.say_goodbye(last=True)
while self.current_neighbors != self.this_round_bye:
logging.debug(
f"current byes {self.this_round_bye} future byes {self.future_byes} current neighbours {self.current_neighbors} current round {self.current_round} and union {self.this_round_bye.union(set(self.future_byes.keys()))}"
)
sender, recv = self.router.recv_multipart()
sender, recv = self.decrypt(sender, recv)
if type(recv) == tuple and recv[0] == BYE:
logging.debug("Received {} from {}".format(BYE, sender))
logging.info(f"disconnect_neighbors: {self.uid} received bye from {sender}")
Jeffrey Wigger
committed
# self.current_neighbors.remove(sender)
self.disconnect_request_handler(sender, recv)
elif type(recv) == tuple and recv[0] == NOBYE:
f"{self.uid}received nobye in disconnect_neighbors {recv} from {sender}"
Jeffrey Wigger
committed
)
self.peer_sockets[sender].send(
self.encrypt((BYE, self.current_round))
)
time.sleep(0.05)
elif type(recv) == tuple:
# this is a hello
logging.debug(f"{self.uid} Other {recv} from {sender}")
Jeffrey Wigger
committed
else:
# this can happen now due to async
logging.info("Received unexpected message from {}".format(sender))
if recv.get("rw", False):
logging.info("Message was rw {}".format(recv["visited"]))
else:
logging.info("Message was normal")
# raise RuntimeError(
# "Received a message when expecting BYE from {}".format(sender)
# )
for sock in self.peer_sockets.values():
sock.close()
self.router.close()
Jeffrey Wigger
committed
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
def rw_sampler_equi(self, message):
    """
    Picks one of the current neighbors uniformly at random.

    Parameters
    ----------
    message : any
        Not used by this sampler.

    Returns
    ----------
    The selected element of current_neighbors

    """
    neighbors = self.current_neighbors
    draw = torch.randint(
        0, len(neighbors), size=(1,), generator=self.random_generator
    )
    index = draw.item()
    logging.debug(
        "rw_sampler_equi selected index {} of {} {}".format(
            index, neighbors, type(neighbors)
        )
    )
    return list(neighbors)[index]
def rw_sampler_equi_check_history(self, message):
    """
    Picks a neighbor uniformly at random, skipping nodes the message already visited.

    Parameters
    ----------
    message : dict or None
        None if the random walk starts here; otherwise a message whose
        "visited" entry lists the nodes to exclude.

    Returns
    ----------
    The selected neighbor, or None if every current neighbor was already visited

    """
    if message is None:  # RW starts from here
        everyone = self.current_neighbors
        index = torch.randint(
            0, len(everyone), size=(1,), generator=self.random_generator
        ).item()
        logging.debug(
            "rw_sampler_equi_check_history selected index {} of {} {}".format(
                index, everyone, type(everyone)
            )
        )
        return list(everyone)[index]

    already_seen = set(message["visited"])
    # current_neighbors is already a set
    candidates = self.current_neighbors.difference(already_seen)
    if not candidates:
        return None
    index = torch.randint(
        0, len(candidates), size=(1,), generator=self.random_generator
    ).item()
    return list(candidates)[index]
def rw_sampler_mh(self, message):
    """
    Placeholder for a Metro hastings version of the sampler (not implemented).

    Intended to sample the neighbour based on the MH weights and to allow
    self loops (will just decrease fuel and reroll!). Currently does nothing.
    """
    return None