# TCP.py
# Authored by Jeffrey Wigger
import importlib
import json
import logging
import pickle
from collections import deque
import zmq
from decentralizepy.communication.Communication import Communication
# Control frames of the connection handshake: HELLO opens a link to a
# neighbour, BYE closes it. Any other frame carries a pickled data message.
HELLO, BYE = b"HELLO", b"BYE"
class TCP(Communication):
    """
    TCP Communication API.

    Each process binds one ZeroMQ ROUTER socket (identity = its uid) on which
    it receives all messages, and opens one DEALER socket per neighbour (with
    the same identity) through which it sends. A HELLO/BYE handshake, tracked
    in ``self.barrier``, brackets the exchange of pickled data messages.

    """

    def addr(self, rank, machine_id):
        """
        Returns TCP address of the process.

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process

        Returns
        -------
        str
            Full address of the process using TCP

        """
        machine_addr = self.ip_addrs[str(machine_id)]
        port = rank + self.offset
        return "tcp://{}:{}".format(machine_addr, port)

    def __init__(
        self,
        rank,
        machine_id,
        mapping,
        total_procs,
        addresses_filepath,
        compress=False,
        offset=20000,
        compression_package=None,
        compression_class=None,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank of the process
        machine_id : int
            Machine id of the process
        mapping : decentralizepy.mappings.Mapping
            uid, rank, machine_id invertible mapping
        total_procs : int
            Total number of processes
        addresses_filepath : str
            JSON file with machine_id -> ip mapping
        compress : bool
            Compress outgoing data and decompress incoming data. Requires
            ``compression_package`` and ``compression_class``.
        offset : int
            Port offset. NOTE: 20000 is added on top of this value, so a
            process listens on port ``rank + 20000 + offset`` (40000 + rank
            with the default). All peers must use the same value.
        compression_package : str
            Import path of a module that implements the compression.Compression.Compression class
        compression_class : str
            Name of the compression class inside the compression package

        Raises
        ------
        AssertionError
            If ``compress`` is set but no compression package/class is given.

        """
        super().__init__(rank, machine_id, mapping, total_procs)
        # Set before connect_neighbors() so that send()/receive() called too
        # early fail their assertion instead of raising AttributeError.
        self.initialized = False
        with open(addresses_filepath) as addrs:
            self.ip_addrs = json.load(addrs)
        self.total_procs = total_procs
        self.rank = rank
        self.machine_id = machine_id
        self.mapping = mapping
        self.offset = 20000 + offset
        self.uid = mapping.get_uid(rank, machine_id)
        self.identity = str(self.uid).encode()
        self.context = zmq.Context()
        # Single receiving socket; the identity lets peers' DEALER sockets be
        # recognised as the sender in recv_multipart().
        self.router = self.context.socket(zmq.ROUTER)
        self.router.setsockopt(zmq.IDENTITY, self.identity)
        self.router.bind(self.addr(rank, machine_id))
        self.sent_disconnections = False
        self.compress = compress
        if compression_package and compression_class:
            compressor_module = importlib.import_module(compression_package)
            compressor_class = getattr(compressor_module, compression_class)
            self.compressor = compressor_class()
            logging.info(f"Using the {compressor_class} to compress the data")
        else:
            assert not self.compress
        # Byte counters: pickled "params" vs. everything else in a message.
        self.total_data = 0
        self.total_meta = 0
        # Data messages received while still waiting for HELLOs.
        self.peer_deque = deque()
        # Encoded neighbour uid -> DEALER socket used to send to it.
        self.peer_sockets = dict()
        # Encoded uids of peers that sent HELLO and have not yet sent BYE.
        self.barrier = set()

    def __del__(self):
        """
        Destroys zmq context, which also closes every socket created from it.

        Safe on a partially constructed instance (e.g. when reading the
        addresses file failed before the context was created).

        """
        context = getattr(self, "context", None)
        if context is not None:
            context.destroy(linger=0)

    def _account(self, output, params):
        """Add the pickled size of ``params`` to total_data and the rest of ``output`` to total_meta."""
        data_len = len(pickle.dumps(params))
        self.total_meta += len(output) - data_len
        self.total_data += data_len

    def encrypt(self, data):
        """
        Encode data as python pickle, compressing it first when enabled.

        Updates the ``total_data`` / ``total_meta`` byte counters.

        Parameters
        ----------
        data : dict
            Data dict to send. The caller's object is not modified.

        Returns
        -------
        bytes
            Encoded data

        """
        if self.compress:
            # Compress a shallow copy: callers may send the same dict to
            # several neighbours, and compressing in place would make every
            # later send() compress already-compressed values.
            data = data.copy()
            if "indices" in data:
                data["indices"] = self.compressor.compress(data["indices"])
            assert "params" in data
            data["params"] = self.compressor.compress_float(data["params"])
            output = pickle.dumps(data)
            # the compressed meta data gets only a few bytes smaller after pickling
            self._account(output, data["params"])
        else:
            output = pickle.dumps(data)
            # centralized testing uses its own instance
            if isinstance(data, dict):
                assert "params" in data
                self._account(output, data["params"])
        return output

    def decrypt(self, sender, data):
        """
        Decode received pickle data, decompressing it when enabled.

        Parameters
        ----------
        sender : bytes
            sender of the data (its encoded uid)
        data : bytes
            Data received

        Returns
        -------
        tuple
            (sender: int, data: dict)

        """
        sender = int(sender.decode())
        # NOTE(review): pickle.loads can execute arbitrary code embedded in
        # the payload; this is only safe when every peer is trusted.
        data = pickle.loads(data)
        if self.compress:
            if "indices" in data:
                data["indices"] = self.compressor.decompress(data["indices"])
            if "params" in data:
                data["params"] = self.compressor.decompress_float(data["params"])
        return sender, data

    def connect_neighbors(self, neighbors):
        """
        Connects all neighbors. Sends HELLO. Waits for HELLO.
        Caches any data received while waiting for HELLOs.
        Marks the instance as initialized on success.

        Parameters
        ----------
        neighbors : list(int)
            List of neighbors

        Raises
        ------
        RuntimeError
            If received BYE while waiting for HELLO

        """
        logging.info("Sending connection request to neighbors")
        for uid in neighbors:
            logging.debug("Connecting to my neighbour: {}".format(uid))
            peer_id = str(uid).encode()
            req = self.context.socket(zmq.DEALER)
            req.setsockopt(zmq.IDENTITY, self.identity)
            # NOTE(review): relies on get_machine_and_rank returning values in
            # addr()'s (rank, machine_id) order - confirm against the mapping.
            req.connect(self.addr(*self.mapping.get_machine_and_rank(uid)))
            self.peer_sockets[peer_id] = req
            req.send(HELLO)

        num_neighbors = len(neighbors)
        while len(self.barrier) < num_neighbors:
            sender, recv = self.router.recv_multipart()

            if recv == HELLO:
                logging.debug("Received {} from {}".format(HELLO, sender))
                self.barrier.add(sender)
            elif recv == BYE:
                logging.debug("Received {} from {}".format(BYE, sender))
                raise RuntimeError(
                    "A neighbour wants to disconnect before training started!"
                )
            else:
                logging.debug(
                    "Received message from {} @ connect_neighbors".format(sender)
                )
                # A peer already past its handshake may send data early;
                # keep it for receive().
                self.peer_deque.append(self.decrypt(sender, recv))

        logging.info("Connected to all neighbors")
        self.initialized = True

    def receive(self):
        """
        Returns ONE message received.

        Messages cached during connect_neighbors() are returned first. BYE
        frames are consumed (removing the sender from the barrier) and do not
        count as a message.

        Returns
        ----------
        tuple
            (sender: int, data: dict) received and decrypted

        Raises
        ------
        RuntimeError
            If received HELLO

        """
        assert self.initialized
        if self.peer_deque:
            return self.peer_deque.popleft()
        while True:
            sender, recv = self.router.recv_multipart()
            if recv == HELLO:
                logging.debug("Received {} from {}".format(HELLO, sender))
                raise RuntimeError(
                    "A neighbour wants to connect when everyone is connected!"
                )
            if recv == BYE:
                logging.debug("Received {} from {}".format(BYE, sender))
                self.barrier.remove(sender)
                continue
            logging.debug("Received message from {}".format(sender))
            return self.decrypt(sender, recv)

    def send(self, uid, data, encrypt=True):
        """
        Send a message to a process.

        Parameters
        ----------
        uid : int
            Neighbor's unique ID
        data : dict
            Message as a Python dictionary
        encrypt : bool
            If False, ``data`` is assumed to be already-encoded bytes and is
            sent unchanged.

        """
        assert self.initialized
        to_send = self.encrypt(data) if encrypt else data
        data_size = len(to_send)
        # NOTE(review): total_bytes is not set in this class; presumably it is
        # initialised by Communication.__init__ - confirm.
        self.total_bytes += data_size
        peer_id = str(uid).encode()
        self.peer_sockets[peer_id].send(to_send)
        logging.debug("{} sent the message to {}.".format(self.uid, uid))
        logging.info("Sent this round: {}".format(data_size))

    def disconnect_neighbors(self):
        """
        Disconnects all neighbors.

        On the first call, sends BYE to every neighbour, then blocks until a
        BYE has been received from every peer still in the barrier. Later
        calls do nothing.

        Raises
        ------
        RuntimeError
            If any frame other than BYE arrives while waiting.

        """
        assert self.initialized
        if not self.sent_disconnections:
            logging.info("Disconnecting neighbors")
            for sock in self.peer_sockets.values():
                sock.send(BYE)
            self.sent_disconnections = True
            while self.barrier:
                sender, recv = self.router.recv_multipart()
                if recv == BYE:
                    logging.debug("Received {} from {}".format(BYE, sender))
                    self.barrier.remove(sender)
                else:
                    logging.critical(
                        "Received unexpected {} from {}".format(recv, sender)
                    )
                    raise RuntimeError(
                        "Received a message when expecting BYE from {}".format(sender)
                    )