import json
import logging
import math
import random

import torch

from decentralizepy.sharing.Sharing import Sharing


class RoundRobinPartial(Sharing):
    """
    This class implements the Round robin partial model sharing.

    The flattened model parameters are split into ``num_blocks`` contiguous
    blocks of ``block_size`` elements. Each call to ``serialized_model``
    sends one block and advances to the next one cyclically, starting from
    a randomly chosen block.

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph representing neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet.
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine)
        alpha : float
            Fraction of the model to share per round (block size is
            ceil(alpha * number of parameters)). Must be > 0.

        Raises
        ------
        ValueError
            If ``alpha`` is not positive or the model reports no parameters,
            i.e. the resulting block size would be smaller than 1.

        """
        super().__init__(
            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
        )
        self.alpha = alpha
        # Seeds the process-global ``random`` module with this node's uid so the
        # starting block chosen below is deterministic per node.
        random.seed(self.mapping.get_uid(rank, machine_id))
        n_params = self.model.count_params()
        logging.info("Total number of parameters: {}".format(n_params))
        self.block_size = math.ceil(self.alpha * n_params)
        if self.block_size < 1:
            # Without this guard a zero block size raises ZeroDivisionError below
            # and a negative one fails obscurely inside random.randint.
            raise ValueError(
                "alpha must be positive and the model must have parameters; "
                "got alpha={} and {} parameters".format(self.alpha, n_params)
            )
        logging.info("Block_size: {}".format(self.block_size))
        self.num_blocks = math.ceil(n_params / self.block_size)
        logging.info("Total number of blocks: {}".format(self.num_blocks))
        self.current_block = random.randint(0, self.num_blocks - 1)

    def serialized_model(self):
        """
        Convert model to json dict. self.alpha specifies the fraction of model to send.

        Only the current block of the flattened parameter vector is sent;
        afterwards the block pointer advances cyclically to the next block.

        Returns
        -------
        dict
            Model converted to json dict. The values under ``"block_start"``,
            ``"block_end"`` and ``"params"`` are each JSON-encoded strings.

        """
        with torch.no_grad():
            logging.info("Extracting params to send")

            tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
            T = torch.cat(tensors_to_cat, dim=0)
            block_start = self.current_block * self.block_size
            # The last block may be shorter than block_size.
            block_end = min(T.shape[0], (self.current_block + 1) * self.block_size)
            self.current_block = (self.current_block + 1) % self.num_blocks
            T_send = T[block_start:block_end]

            logging.info("Range sending: {}-{}".format(block_start, block_end))
            logging.info("Generating dictionary to send")

            m = dict()
            m["block_start"] = block_start
            m["block_end"] = block_end
            # .cpu() returns the tensor itself when it is already on the CPU and
            # lets .numpy() work for models placed on another device.
            m["params"] = T_send.cpu().numpy().tolist()

            logging.info("Elements sending: {}".format(len(m["params"])))
            logging.info("Generated dictionary to send")

            for key in m:
                m[key] = json.dumps(m[key])

            logging.info("Converted dictionary to json")
            self.total_data += len(self.communication.encrypt(m["params"]))
            return m

    def deserialized_model(self, m):
        """
        Convert received json dict to state_dict.

        The received block overwrites positions [block_start, block_end) of a
        flattened copy of the local state_dict; all other entries keep their
        local values.

        Parameters
        ----------
        m : dict
            json dict received

        Returns
        -------
        state_dict
            state_dict of received

        """
        # NOTE(review): the sender flattens ``model.parameters()`` while this
        # method flattens ``model.state_dict()``. The offsets only line up if the
        # state_dict holds no buffers (e.g. BatchNorm running stats) and no tied
        # parameters -- confirm for the models used with this sharing scheme.
        with torch.no_grad():
            state_dict = self.model.state_dict()
            shapes = []
            lens = []
            tensors_to_cat = []
            for v in state_dict.values():
                shapes.append(v.shape)
                t = v.flatten()
                lens.append(t.shape[0])
                tensors_to_cat.append(t)

            # torch.cat produces a new tensor, so writing into T does not modify
            # the model's own parameters.
            T = torch.cat(tensors_to_cat, dim=0)
            block_start = json.loads(m["block_start"])
            block_end = json.loads(m["block_end"])
            T[block_start:block_end] = torch.tensor(json.loads(m["params"]))

            # Split the flat vector back into tensors with the original shapes.
            start_index = 0
            for i, key in enumerate(state_dict):
                end_index = start_index + lens[i]
                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
                start_index = end_index

            return state_dict