# PartialModel.py — partial (top-k gradient) model sharing for decentralizepy.
import json
import logging
import os

import numpy
import torch

from decentralizepy.sharing.Sharing import Sharing

class PartialModel(Sharing):
    """
    Sharing variant that communicates only the top-``alpha`` fraction of the
    model's flattened parameters, selected by largest accumulated-gradient
    magnitude, instead of the full model.
    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
    ):
        """
        Constructor. The first eight parameters are forwarded unchanged to
        ``Sharing.__init__``; in addition:

        Parameters
        ----------
        log_dir : str
            Directory where ``<rank>_shared_params.json`` is written.
            (Bug fix: the original omitted this parameter yet passed
            ``log_dir`` to ``super().__init__`` — a guaranteed NameError.)
        alpha : float
            Fraction of parameter entries to share each round (0 < alpha <= 1).
        dict_ordered : bool
            Whether stable dict ordering may be assumed when (de)serializing;
            only ``True`` is implemented.
        """
        super().__init__(
            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
        )
        self.alpha = alpha
        self.dict_ordered = dict_ordered
        self.communication_round = 0

    def extract_top_gradients(self):
        """
        Sum the gradients accumulated on the model and select the top-``alpha``
        fraction of entries by absolute magnitude.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            ``(values, flat_indices)`` as produced by ``torch.topk``
            (``sorted=False``, so order is unspecified).
        """
        logging.info("Summing up gradients")
        assert len(self.model.accumulated_gradients) > 0
        gradient_sum = self.model.accumulated_gradients[0]
        # Accumulate the remaining gradient dicts element-wise into the first.
        for gradient in self.model.accumulated_gradients[1:]:
            for key in gradient:
                gradient_sum[key] += gradient[key]

        logging.info("Returning topk gradients")
        tensors_to_cat = [v.data.flatten() for _, v in gradient_sum.items()]
        G_topk = torch.abs(torch.cat(tensors_to_cat, dim=0))
        return torch.topk(
            G_topk, round(self.alpha * G_topk.shape[0]), dim=0, sorted=False
        )

    def serialized_model(self):
        """
        Serialize the top-``alpha`` parameters for sending, and append the
        shared indices for this round to ``<log_dir>/<rank>_shared_params.json``.

        Returns
        -------
        dict
            ``{"indices": <json list of flat indices>,
               "params": <json list of values>}`` (both values json-encoded).
        """
        with torch.no_grad():
            _, G_topk = self.extract_top_gradients()

            # One file per rank records which indices were shared each round.
            shared_params_path = os.path.join(
                self.log_dir, "{}_shared_params.json".format(self.rank)
            )
            if self.communication_round:
                with open(shared_params_path, "r") as inf:
                    shared_params = json.load(inf)
            else:
                shared_params = dict()
                # Bug fixes: dict_keys is not json-serializable, state_dict
                # must be called, and torch.Size has no tolist() — use list().
                shared_params["order"] = list(self.model.state_dict().keys())
                shapes = dict()
                for k, v in self.model.state_dict().items():
                    shapes[k] = list(v.shape)
                shared_params["shapes"] = shapes

            shared_params[self.communication_round] = G_topk.tolist()

            with open(shared_params_path, "w") as of:
                json.dump(shared_params, of)

            logging.info("Extracting topk params")
            tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
            T = torch.cat(tensors_to_cat, dim=0)
            T_topk = T[G_topk]

            logging.info("Generating dictionary to send")
            m = dict()
            if not self.dict_ordered:
                raise NotImplementedError

            m["indices"] = G_topk.numpy().tolist()
            m["params"] = T_topk.numpy().tolist()
            assert len(m["indices"]) == len(m["params"])
            logging.info("Elements sending: {}".format(len(m["indices"])))
            logging.info("Generated dictionary to send")
            for key in m:
                m[key] = json.dumps(m[key])
            logging.info("Converted dictionary to json")

            return m

    def deserialized_model(self, m):
        """
        Reconstruct a full state_dict from a received partial-model message,
        overwriting the received flat indices on top of this node's current
        parameters.

        Parameters
        ----------
        m : dict
            ``{"indices": ..., "params": ...}`` with json-encoded lists, as
            produced by ``serialized_model``.

        Returns
        -------
        dict
            A state_dict mapping parameter names to reshaped tensors.
        """
        with torch.no_grad():
            state_dict = self.model.state_dict()

            if not self.dict_ordered:
                raise NotImplementedError

            # Flatten all parameters into one vector, remembering how to
            # reshape each back afterwards.
            shapes = []
            lens = []
            tensors_to_cat = []
            for _, v in state_dict.items():
                shapes.append(v.shape)
                t = v.flatten()
                lens.append(t.shape[0])
                tensors_to_cat.append(t)

            T = torch.cat(tensors_to_cat, dim=0)
            index_tensor = torch.tensor(json.loads(m["indices"]))
            logging.debug("Original tensor: {}".format(T[index_tensor]))
            T[index_tensor] = torch.tensor(json.loads(m["params"]))
            logging.debug("Final tensor: {}".format(T[index_tensor]))

            # Slice the flat vector back into per-parameter tensors.
            start_index = 0
            for i, key in enumerate(state_dict):
                end_index = start_index + lens[i]
                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
                start_index = end_index

            return state_dict