import faulthandler
import importlib
import json
import logging
import lzma
import pickle
import time
import traceback
import weakref
from collections import deque
from ctypes import c_int, c_long
from multiprocessing import Lock
from multiprocessing.sharedctypes import Value
from queue import Empty

import torch
import zmq
# Make the interpreter dump Python tracebacks on fatal errors.
faulthandler.enable()
import torch.multiprocessing as mp
from decentralizepy.communication.Communication import Communication

# Raw (never pickled) control frames of the neighbour handshake:
# HELLO is sent when connecting, BYE when disconnecting.
HELLO, BYE = b"HELLO", b"BYE"
class TCPRandomWalkBase(Communication):
"""
Shared base class of TCPRandomWalk and TCPRandomWalkInternal.
It provides the address lookup (addr), a constructor that loads the address
file and the optional compressor, and the encrypt/decrypt helpers copied
from TCP.py (pickle serialization with optional compression).
Both subclasses need encrypt and decrypt because the sharing interfaces can
call them on either communication object.
"""
def addr(self, rank, machine_id):
"""
Returns the TCP (zmq) endpoint of a process.
The host is looked up in ``self.ip_addrs`` (the JSON address file loaded by
the constructor) under the key ``str(machine_id)``.
Parameters
----------
rank : int
Local rank of the process
machine_id : int
Machine id of the process
Returns
-------
str
Full address of the process, formatted as ``tcp://<host>:<port>``
Raises
------
KeyError
If ``machine_id`` has no entry in the address file
"""
machine_addr = self.ip_addrs[str(machine_id)]
# NOTE(review): `port` is not assigned anywhere in the visible code, so this
# method raises NameError as shown. The line computing it (presumably from
# `rank` and `self.offset`) appears to have been lost; restore it from
# version control.
Jeffrey Wigger
committed
return "tcp://{}:{}".format(machine_addr, port)
def __init__(
    self,
    rank,
    machine_id,
    mapping,
    total_procs,
    addresses_filepath,
    compress=False,
    offset=2000,
    compression_package=None,
    compression_class=None,
):
    """
    Constructor

    Parameters
    ----------
    rank : int
        Local rank of the process
    machine_id : int
        Machine id of the process
    mapping : decentralizepy.mappings.Mapping
        uid, rank, machine_id invertible mapping
    total_procs : int
        Total number of processes
    addresses_filepath : str
        JSON file with machine_id -> ip mapping
    compress : bool
        Whether encrypt/decrypt (de)compress the "indices" and "params"
        fields. Requires compression_package and compression_class.
    offset : int
        Port offset; ``self.offset`` is set to ``10000 + offset``
    compression_package : str
        Import path of a module that implements the compression.Compression.Compression class
    compression_class : str
        Name of the compression class inside the compression package

    Raises
    ------
    ValueError
        If ``compress`` is set but no compression package/class is given
    """
    super().__init__(rank, machine_id, mapping, total_procs)
    self.addresses_filepath = addresses_filepath
    self.compress = compress
    self.offset = 10000 + offset
    self.compression_package = compression_package
    self.compression_class = compression_class

    with open(addresses_filepath) as addrs:
        self.ip_addrs = json.load(addrs)

    # zmq identity of this process: its uid as ASCII bytes.
    self.identity = str(self.uid).encode()

    # Always define the attribute, even when compression is not configured.
    self.compressor = None
    if compression_package and compression_class:
        compressor_module = importlib.import_module(compression_package)
        compressor_class = getattr(compressor_module, compression_class)
        self.compressor = compressor_class()
        logging.info(f"Using the {compressor_class} to compress the data")
    elif self.compress:
        # Explicit check instead of `assert`, which is stripped under -O and
        # would defer the failure to the first encrypt/decrypt call.
        raise ValueError(
            "compress=True requires compression_package and compression_class"
        )

    # Dummy values; TCPRandomWalk and TCPRandomWalkInternal replace them with
    # shared counters exposing `.value`, which encrypt() updates.
    self.total_data = 0
    self.total_meta = 0
def encrypt(self, data):
    """
    Encode data as python pickle, optionally compressing its payload.

    When ``self.compress`` is set, ``data["indices"]`` (if present) is passed
    through ``self.compressor.compress`` and ``data["params"]`` through
    ``self.compressor.compress_float`` before pickling. This is done on a
    shallow copy, so the caller's dict is never replaced by its compressed
    form. Otherwise sending the same dict twice would compress it twice.

    For dict messages, the pickled size of ``params`` is added to
    ``self.total_data`` and the remaining bytes of the pickled message to
    ``self.total_meta``. Non-dict payloads are pickled without accounting.

    Parameters
    ----------
    data : dict or picklable object
        Data to send; dict messages must contain "params"

    Returns
    -------
    bytes
        Encoded data
    """
    if self.compress:
        logging.debug("in encrypt: compress")
        assert "params" in data
        # Shallow copy: the caller keeps its uncompressed message.
        data = data.copy()
        if "indices" in data:
            data["indices"] = self.compressor.compress(data["indices"])
        data["params"] = self.compressor.compress_float(data["params"])
    else:
        logging.debug("in encrypt: else")

    output = pickle.dumps(data)

    # centralized testing uses its own instance (non-dict payloads)
    if self.compress or isinstance(data, dict):
        assert "params" in data
        data_len = len(pickle.dumps(data["params"]))
        # the compressed meta data gets only a few bytes smaller after pickling
        self.total_meta.value += len(output) - data_len
        self.total_data.value += data_len
    return output
def decrypt(self, sender, data):
    """
    Decode received pickle data, undoing compression when it is enabled.

    Parameters
    ----------
    sender : bytes
        zmq identity of the sender, i.e. its uid as ASCII digits
    data : bytes
        Pickled message as produced by ``encrypt``

    Returns
    -------
    tuple
        ``(sender: int, data)`` where the "indices" / "params" entries are
        decompressed if ``self.compress`` is set and they are present

    Notes
    -----
    ``pickle.loads`` can execute arbitrary code embedded in the payload; only
    exchange messages with trusted peers.
    """
    sender_uid = int(sender.decode())
    compressed = self.compress
    if compressed:
        logging.debug("in decrypt:comp")
    message = pickle.loads(data)
    if compressed:
        # field name -> compressor method that undoes encrypt()'s compression
        for field, method_name in (
            ("indices", "decompress"),
            ("params", "decompress_float"),
        ):
            if field in message:
                message[field] = getattr(self.compressor, method_name)(message[field])
    return sender_uid, message
class TCPRandomWalk(TCPRandomWalkBase):
"""
Front-end of the random-walk communication, used by the calling process.
It creates no zmq sockets itself: connect_neighbors spawns a
TCPRandomWalkInternal process, and messages are exchanged with that process
through ``send_queue`` and ``recv_queue``. Mostly a copy of the
Communication interface.
"""
def __init__(
    self,
    rank,
    machine_id,
    mapping,
    total_procs,
    addresses_filepath,
    compress=False,
    offset=2000,
    compression_package=None,
    compression_class=None,
    sampler="equi",
):
    """
    Constructor

    Parameters
    ----------
    rank : int
        Local rank of the process
    machine_id : int
        Machine id of the process
    mapping : decentralizepy.mappings.Mapping
        uid, rank, machine_id invertible mapping
    total_procs : int
        Total number of processes
    addresses_filepath : str
        JSON file with machine_id -> ip mapping
    compress : bool
        Whether messages are (de)compressed
    offset : int
        Port offset forwarded to TCPRandomWalkBase
    compression_package : str
        Import path of the compression module
    compression_class : str
        Name of the compression class inside the compression package
    sampler : str
        Random-walk neighbour sampler name handed to the internal process
    """
    super().__init__(
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        compress,
        offset,
        compression_package,
        compression_class,
    )
    self.sampler = sampler

    # Bounded (1000 entries) queues to and from the TCPRandomWalkInternal process.
    self.send_queue = mp.Queue(1000)
    self.recv_queue = mp.Queue(1000)

    # The three traffic counters are shared with the internal process and all
    # use this single lock. `counter.value += n` still takes the lock for the
    # read and for the write separately, so concurrent increments are not atomic.
    self.lock = mp.Lock()
    self.total_data, self.total_meta, self.total_bytes = (
        mp.Value(c_long, 0, lock=self.lock) for _counter in range(3)
    )
    # Run flag polled by the internal comm_loop; created without a lock.
    self.flag_running = mp.Value(c_int, 0, lock=False)
def connect_neighbors(self, neighbors):
    """
    Spawns the TCPRandomWalkInternal process, which connects to the neighbours.
    This function should only be called once.

    Parameters
    ----------
    neighbors : list(int)
        List of neighbors
    """
    self.flag_running.value = 1

    def _run_internal(_process_index, *internal_args):
        # start_processes calls fn(i, *args); the internal constructor does
        # not take the process index.
        return TCPRandomWalkInternal(*internal_args)

    # Positional order must match TCPRandomWalkInternal.__init__.
    internal_args = (
        self.rank,
        self.machine_id,
        self.mapping,
        self.total_procs,
        self.addresses_filepath,
        self.send_queue,
        self.recv_queue,
        self.flag_running,
        neighbors,
        self.total_data,
        self.total_meta,
        self.total_bytes,
        self.compress,
        self.offset,
        self.compression_package,
        self.compression_class,
        self.sampler,
    )
    self.ctx = mp.start_processes(
        _run_internal,
        args=internal_args,
        start_method="fork",
        join=False,
    )
def receive(self, block=True, timeout=None):
    """
    Returns a message received by the internal communication process.

    Parameters
    ----------
    block : bool
        Whether to wait for a message
    timeout : float or None
        Maximum number of seconds to wait when ``block`` is True; None waits
        indefinitely

    Returns
    ----------
    tuple
        ``(sender, data)`` already decrypted by the internal process, or
        ``(None, None)`` if no message is available (non-blocking call or
        timeout expired)
    """
    try:
        # Honour the caller's timeout; it used to be silently replaced by None.
        message = self.recv_queue.get(block=block, timeout=timeout)  # already decrypted
    except Empty:
        return None, None
    logging.debug("Receive in TCPRandomWalk; post get")
    return message
"""
Send a message to a process.
Parameters
----------
uid : int
Neighbor's unique ID
data : dict
Message as a Python dictionary
"""
if uid is not None:
logging.debug("Send to %i in TCPRandomWalk", uid)
Jeffrey Wigger
committed
def disconnect_neighbors(self):
    """
    Disconnects all neighbors.

    First flushes pending outgoing messages to the internal process. Then it
    clears ``flag_running``; the internal comm_loop exits once idle and calls
    its own disconnect_neighbors, which sends BYE to the neighbours. Finally
    the receive queue is closed in this process. No further send/receive is
    possible afterwards.
    """
    #del self.lock
    # multiprocessing.Queue.join_thread() may only be called after close().
    # Closing first also makes join_thread() wait until every queued message
    # has been written to the pipe, before the internal process is told to stop.
    self.send_queue.close()
    self.send_queue.join_thread()
    # Signal the internal comm_loop to finish once nothing is left to do.
    self.flag_running.value = 0
    self.recv_queue.close()
    self.recv_queue.join_thread()
    print(f"disconnect_neighbors: joined {self.uid}")
class TCPRandomWalkInternal(TCPRandomWalkBase):
"""
Communication process spawned by TCPRandomWalk.connect_neighbors.
Its constructor binds a zmq ROUTER socket, connects a DEALER socket to every
neighbour and then runs comm_loop. comm_loop moves messages between the
sockets and the send/receive queues shared with TCPRandomWalk; random-walk
("rw") messages are routed to neighbours chosen by self.rw_sampler.
"""
def __init__(
    self,
    rank,
    machine_id,
    mapping,
    total_procs,
    addresses_filepath,
    send_queue: mp.Queue,
    recv_queue: mp.Queue,
    flag_running,
    neighbors,
    total_data,
    total_meta,
    total_bytes,
    compress=False,
    offset=2000,
    compression_package=None,
    compression_class=None,
    sampler="equi",
):
    """
    Constructor. Runs the whole lifetime of the communication process: binds
    the ROUTER socket, connects to ``neighbors`` and then blocks in
    ``comm_loop`` until shutdown.

    Any exception raised after the base constructor (including
    KeyboardInterrupt) is printed and logged, and the queues are drained and
    closed; it is not re-raised.

    Parameters
    ----------
    rank : int
        Local rank of the process
    machine_id : int
        Machine id of the process
    mapping : decentralizepy.mappings.Mapping
        uid, rank, machine_id invertible mapping
    total_procs : int
        Total number of processes
    addresses_filepath : str
        JSON file with machine_id -> ip mapping
    send_queue : mp.Queue
        Outgoing messages handed over by TCPRandomWalk
    recv_queue : mp.Queue
        Received, decrypted messages handed back to TCPRandomWalk
    flag_running : shared value with ``.value``
        comm_loop finishes once it is 0 and the loop is idle
    neighbors : collection of int
        uids of the neighbours to connect to
    total_data, total_meta, total_bytes : shared values with ``.value``
        Traffic counters shared with TCPRandomWalk
    compress, offset, compression_package, compression_class
        Forwarded to TCPRandomWalkBase
    sampler : str
        "equi" or "equi_check_history"; anything else falls back to "equi"
    """
    super().__init__(rank, machine_id, mapping, total_procs, addresses_filepath, compress, offset,
                     compression_package, compression_class)
    print("post super")
    try:
        self.sampler = sampler
        self.context = zmq.Context()
        self.router = self.context.socket(zmq.ROUTER)
        self.router.setsockopt(zmq.IDENTITY, self.identity)
        self.router.bind(self.addr(rank, machine_id))
        self.sent_disconnections = False
        # (self.compress is already set by the base constructor)
        self.total_data = total_data
        self.total_meta = total_meta
        self.total_bytes = total_bytes
        self.send_queue = send_queue
        self.recv_queue = recv_queue
        self.flag_running = flag_running
        self.neighbors = neighbors
        self.random_generator = torch.Generator()
        # new random seed from the random device
        self.rw_seed = self.random_generator.seed()
        logging.info("Machine %i has random seed %i for RW", self.uid, self.rw_seed)
        self.rw_messages_stat = []
        self.rw_double_count_stat = []

        if self.sampler == "equi_check_history":
            logging.info("rw_samper is rw_sampler_equi_check_history")
            sampler_method = self.rw_sampler_equi_check_history
        else:
            if self.sampler == "equi":
                logging.info("rw_samper is rw_sampler_equi")
            else:
                logging.info("rw_samper is rw_sampler_equi (default)")
            sampler_method = self.rw_sampler_equi
        # Stored as a WeakMethod (presumably so self does not reference itself
        # through a bound method); call as self.rw_sampler()(message).
        self.rw_sampler = weakref.WeakMethod(sampler_method)

        self.peer_sockets = dict()
        self.barrier = set()
        self.connect_neighbors(self.neighbors)
        self.comm_loop()
    except BaseException:
        # Best-effort cleanup of a failing communication process.
        error_message = traceback.format_exc()
        print(error_message)
        logging.debug("GOT EXCEPTION")
        logging.debug(error_message)
        # Use the arguments rather than self.*: the failure may have happened
        # before the attributes were assigned. The queues may also be closed
        # already (receive() closes them on an unexpected HELLO).
        while True:
            try:
                recv_queue.get_nowait()
            except (Empty, ValueError, OSError):
                break
            print(f"{self.uid}: clear rcv")
        # join_thread() is only allowed after close().
        recv_queue.close()
        recv_queue.join_thread()
        print(f"{self.uid}: joined recv")
        while True:
            try:
                send_queue.get_nowait()
            except (Empty, ValueError, OSError):
                break
            print(f"{self.uid}: clear snd")
        print(f"{self.uid}: joined send")
def __del__(self):
    """
    Destroys the zmq context with linger=0.

    ``__init__`` can fail before ``self.context`` is created; in that case
    there is nothing to destroy and no AttributeError is raised here.
    """
    context = getattr(self, "context", None)
    if context is not None:
        context.destroy(linger=0)
    logging.info(f"TCPRandomWalkInternal for {getattr(self, 'uid', None)} is destroyed")
    # uncomment during debugging
    #self.recv_queue.close()
    #self.send_queue.close()
def connect_neighbors(self, neighbors):
    """
    Connects all neighbors. Sends HELLO. Waits for HELLO.
    Data frames that arrive while waiting are decrypted and put on
    ``recv_queue`` right away.

    Parameters
    ----------
    neighbors : list(int)
        List of neighbors

    Raises
    ------
    RuntimeError
        If received BYE while waiting for HELLO
    """
    logging.info("Sending connection request to neighbors")
    for neighbor_uid in neighbors:
        logging.debug("Connecting to my neighbour: {}".format(neighbor_uid))
        peer_key = str(neighbor_uid).encode()
        dealer = self.context.socket(zmq.DEALER)
        dealer.setsockopt(zmq.IDENTITY, self.identity)
        dealer.connect(self.addr(*self.mapping.get_machine_and_rank(neighbor_uid)))
        self.peer_sockets[peer_key] = dealer
        dealer.send(HELLO)

    expected_hellos = len(neighbors)
    while len(self.barrier) < expected_hellos:
        sender, recv = self.router.recv_multipart()
        if recv == HELLO:
            logging.debug("Received {} from {}".format(HELLO, sender))
            self.barrier.add(sender)
            continue
        if recv == BYE:
            logging.debug("Received {} from {}".format(BYE, sender))
            raise RuntimeError(
                "A neighbour wants to disconnect before training started!"
            )
        # A neighbour that finished its handshake already sent data.
        logging.debug(
            "Received message from {} @ connect_neighbors".format(sender)
        )
        self.recv_queue.put(self.decrypt(sender, recv))

    logging.info("Connected to all neighbors")
    self.initialized = True
def comm_loop(self):
"""
Main loop of the communication process.
Each iteration polls the ROUTER socket without blocking and hands any frame
to ``self.receive``. It then (presumably; see the NOTE below) takes one
outgoing message from ``self.send_queue`` and sends it.
Each idle attempt multiplies the sleep time by 1.1. Every successful
receive/send adds 5 to ``successes``, and the sleep time is divided by it.
The result is clamped to [1e-5, 5e-3] seconds.
The loop ends once ``flag_running`` is 0 and an iteration neither received
nor sent anything; it then calls ``disconnect_neighbors``.
A RuntimeError raised by ``self.receive`` (unexpected HELLO) is not caught
here and propagates out of the loop.
"""
# TODO: May want separate loops for send and receive
sleep_time = 0.001
while True:
successes = 1
flushed = True  # stays True only if nothing was received or sent this iteration
try:
sender, recv = self.router.recv_multipart(flags=zmq.NOBLOCK)
logging.debug("comm_loop received from %i", int(sender.decode()))
self.receive(sender, recv)
successes += 5
flushed = False
except zmq.ZMQError:
# logging.debug("zmq error due to empty recv_multipart")
sleep_time *= 1.1
try:
Jeffrey Wigger
committed
# NOTE(review): the statements that dequeue `uid` and the message from
# `self.send_queue` (the `except Empty` below implies a queue get) and that
# actually send the message are missing from this view; `uid` is undefined
# as shown. Restore them from version control.
# send may block if send queue is full
if uid is not None:
logging.debug("comm_loop will send to %i", uid)
else:
logging.debug("comm_loop will send rw")
Jeffrey Wigger
committed
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
successes += 5
flushed = False
except Empty:
sleep_time *= 1.1
# logging.debug("empty send queue %i", self.send_queue.empty())
# Back off while idle, react quickly after traffic.
sleep_time /= successes
sleep_time = max(0.00001, sleep_time)
sleep_time = min(0.005, sleep_time)
# logging.debug("comm loop sleeping for %f", sleep_time)
time.sleep(sleep_time)
if self.flag_running.value == 0 and flushed:
break
self.disconnect_neighbors()
print(f"DISCONNECTED {self.uid}")
def receive(self, sender, recv):
"""
Handles ONE frame received on the ROUTER socket.
- HELLO: a neighbour tries to connect after setup finished. Clears
  ``flag_running``, closes both queues and raises RuntimeError.
- BYE: removes ``sender`` from ``self.barrier``.
- Anything else: decrypts it and puts ``(src, data)`` on ``recv_queue``.
  For random-walk messages (``data["rw"]``), ``src`` is
  ``data["visited"][0]``. If fuel is left, the message is also prepared for
  forwarding to a neighbour chosen by ``self.rw_sampler``.
Parameters
----------
sender : bytes
zmq identity of the neighbour the frame came from
recv : bytes
Frame payload
Returns
----------
None
Raises
------
RuntimeError
If received HELLO
"""
if recv == HELLO:
logging.debug("Received {} from {}".format(HELLO, sender))
# TODO: how to properly shut down here
self.flag_running.value = False
self.recv_queue.close()
self.send_queue.close()
raise RuntimeError(
"A neighbour wants to connect when everyone is connected!"
)
elif recv == BYE:
logging.debug("Received {} from {}".format(BYE, sender))
self.barrier.remove(sender)
else:
# TODO: here process receive of rw messages
logging.debug("Received message from {}".format(sender))
src, data = self.decrypt(sender, recv)
# need a new data object such that forwarding does not change the internal fields
# We add our selves to visited so, it could trigger "RW message was already once received"
# NOTE(review): data.copy() is shallow; only "visited" is copied again
# below, other mutable values stay shared with `data`.
new_data = data.copy()
if data.get("rw", False):
new_data["visited"] = new_data["visited"].copy()
src = new_data["visited"][
0
] # first "visited" entry (presumably the walk's origin), not the forwarding neighbour `sender`
logging.debug(
"Received message from {} was rw {}".format(sender, data["visited"])
)
if data["fuel"] > 0:
new_neighbor = self.rw_sampler()(data)
if new_neighbor == None:
logging.info(
"dropped rw message due to no new neigbor being available: %s",
str(data["visited"]),
)
else:
# Only the outgoing `data` is updated (one unit of fuel spent, this
# node appended to "visited"); new_data keeps the values as received.
data["fuel"] -= 1
assert data["fuel"] + 1 == new_data["fuel"]
visited = data["visited"]
visited.append(self.uid)
data["visited"] = visited
logging.debug(
"Forward rw {} to {}".format(data["visited"], new_neighbor)
)
# NOTE(review): the statement that actually forwards `data` to
# `new_neighbor` (presumably via self.send) is missing from this view.
Jeffrey Wigger
committed
else:
# no fuel left: the message is not forwarded any further
logging.info("dropped rw message with fuel %i ", data["fuel"])
# Deliver the message (as received) to the local TCPRandomWalk.
self.recv_queue.put((src, new_data))
logging.debug(
"Received message from {} after putting into the queue".format(sender)
)
Jeffrey Wigger
committed
# NOTE(review): the `def send(...)` line of this method is missing from this
# view. The body reads `uid`, `data` and `encrypt`, so it was presumably
# `def send(self, uid, data, encrypt=True):` -- restore it from version control.
"""
Send a message to a process.
Parameters
----------
uid : int
Neighbor's unique ID. If it is none then it means we sample the neighbour!
data : dict
Message as a Python dictionary
encrypt : bool
If true, ``data`` is serialized with ``self.encrypt``; otherwise it is
sent as-is (so it must already be bytes) and never treated as a
random-walk message
"""
assert self.initialized == True
if encrypt:
rw = data.get("rw", False)
logging.debug("send: rw? {}".format(rw))
to_send = self.encrypt(data)
else:
rw = False
logging.debug("send: rw? {}".format(rw))
to_send = data
Jeffrey Wigger
committed
data_size = len(to_send)
# NOTE(review): bytes are counted here even when the rw message is dropped
# below because of shutdown.
self.total_bytes.value += data_size
if uid is None and rw: # a rw message
Jeffrey Wigger
committed
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
# NOTE(review): lines may be missing in the gap above.
if self.flag_running.value == 1: # Do not send rw if we are shutting down
uid = self.rw_sampler()(data)
# NOTE(review): rw_sampler_equi_check_history can return None; the lookup
# below would then use b"None" and fail with KeyError.
logging.debug("send: rw to {}".format(uid))
else:
logging.debug("send: rw dropped due to being in shutdown")
return
# peer_sockets is keyed by str(uid).encode() (see connect_neighbors); `id` shadows the builtin.
id = str(uid).encode()
self.peer_sockets[id].send(to_send)
logging.debug("{} sent the message to {}.".format(self.uid, uid))
logging.info("Sent this round: {}".format(data_size))
def disconnect_neighbors(self):
    """
    Disconnects all neighbors.

    Sends BYE to every peer socket, only once (guarded by
    ``self.sent_disconnections``). It then waits until a BYE has been received
    from every neighbour still in ``self.barrier``. Data messages that arrive
    meanwhile (possible because sending is asynchronous) are decrypted and
    logged but not delivered. Finally all peer sockets and the router are closed.

    Requires connect_neighbors to have completed (``self.initialized``).
    """
    assert self.initialized
    if not self.sent_disconnections:
        logging.info("Disconnecting neighbors")
        for sock in self.peer_sockets.values():
            sock.send(BYE)
        self.sent_disconnections = True

    while self.barrier:
        sender, recv = self.router.recv_multipart()
        if recv == BYE:
            logging.debug("Received {} from {}".format(BYE, sender))
            self.barrier.remove(sender)
            continue
        # this can happen now due to async
        logging.info("Received unexpected message from {}".format(sender))
        _, data = self.decrypt(sender, recv)
        # encrypt() also accepts non-dict payloads, so do not assume a dict.
        if isinstance(data, dict) and data.get("rw", False):
            logging.info("Message was rw {}".format(data["visited"]))
        else:
            logging.info("Message was normal")

    for sock in self.peer_sockets.values():
        sock.close()
    self.router.close()
def rw_sampler_equi(self, message):
    """
    Samples the next random-walk hop uniformly among all neighbours.

    ``message`` is ignored. Draws from ``self.random_generator``.

    Returns
    -------
    int
        One element of ``self.neighbors``
    """
    candidates = self.neighbors
    upper = len(candidates)
    picked = torch.randint(
        0, upper, size=(1,), generator=self.random_generator
    ).item()
    logging.debug(
        "rw_sampler_equi selected index {} of {} {}".format(
            picked, candidates, type(candidates)
        )
    )
    ordered = list(candidates)
    return ordered[picked]
def rw_sampler_equi_check_history(self, message):
    """
    Samples the next random-walk hop uniformly among neighbours the walk has
    not visited yet.

    Parameters
    ----------
    message : dict or None
        The random-walk message being forwarded (its "visited" list is
        consulted), or None when the walk starts here. In that case any
        neighbour may be chosen.

    Returns
    -------
    int or None
        The sampled neighbour, or None if every neighbour was already visited
    """
    if message is None:  # RW starts from here
        index = torch.randint(
            0, len(self.neighbors), size=(1,), generator=self.random_generator
        ).item()
        logging.debug(
            "rw_sampler_equi_check_history selected index {} of {} {}".format(
                index, self.neighbors, type(self.neighbors)
            )
        )
        return list(self.neighbors)[index]

    # set() also accepts a list of neighbours (the type documented by
    # connect_neighbors); list.difference does not exist.
    possible_neighbors = set(self.neighbors).difference(message["visited"])
    if not possible_neighbors:
        return None
    index = torch.randint(
        0,
        len(possible_neighbors),
        size=(1,),
        generator=self.random_generator,
    ).item()
    return list(possible_neighbors)[index]
def rw_sampler_mh(self, message):
"""
Metropolis-Hastings neighbour sampler -- not implemented yet.
Currently a stub that returns None. None of the ``sampler`` values handled
in TCPRandomWalkInternal.__init__ select it.
"""
# Metro hastings version of the sampler
# Samples the neighbour based on the MH weights, allows self loops (will just decrease fuel and reroll!)
pass