import json
import logging
import os

import torch

from decentralizepy.sharing.Sharing import Sharing


class PartialModel(Sharing):
    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
    ):
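        """
        Partial-model sharing: every round, only the `alpha` fraction of
        model parameters with the largest accumulated-gradient magnitude
        is exchanged with neighbours.

        alpha : float
            Fraction of parameters to share, in (0, 1].
        dict_ordered : bool
            Whether state_dict iteration order can be relied upon; only
            ordered dicts are currently supported.
        """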
        super().__init__(
            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
        )
        self.alpha = alpha
        self.dict_ordered = dict_ordered
        self.communication_round = 0

    def extract_top_gradients(self):
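        """
        Sum the locally accumulated gradients and return the top-k entries
        of the summed gradient by absolute value, as (values, flat indices),
        with k = round(alpha * total number of parameters).
        """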
        logging.info("Summing up gradients")
        assert len(self.model.accumulated_gradients) > 0
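        # Sum all accumulated gradient dicts in place into the first one.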
        gradient_sum = self.model.accumulated_gradients[0]
        for i in range(1, len(self.model.accumulated_gradients)):
            for key in self.model.accumulated_gradients[i]:
                gradient_sum[key] += self.model.accumulated_gradients[i][key]

        logging.info("Returning topk gradients")
        tensors_to_cat = [v.data.flatten() for v in gradient_sum.values()]
        G = torch.abs(torch.cat(tensors_to_cat, dim=0))
        return torch.topk(
            G, round(self.alpha * G.shape[0]), dim=0, sorted=False
        )

    def serialized_model(self):
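        """
        Build the outgoing message: pick the top-k gradient coordinates,
        log which indices were shared to
        <log_dir>/<rank>_shared_params.json, and return a dict with the
        JSON-encoded flat indices and their parameter values.
        """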
        with torch.no_grad():
            _, G_topk = self.extract_top_gradients()

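            # From round 1 onwards, extend the existing per-rank log of
            # shared indices instead of starting a new one.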
            if self.communication_round:
                with open(
                    os.path.join(
                        self.log_dir, "{}_shared_params.json".format(self.rank)
                    ),
                    "r",
                ) as inf:
                    shared_params = json.load(inf)
            else:
                shared_params = dict()
                shared_params["order"] = self.model.state_dict().keys()
                shapes = dict()
                for k, v in self.model.state_dict.items():
                    shapes[k] = v.shape.tolist()
                shared_params["shapes"] = shapes

            # json.load returns string keys, so key by str for consistency.
            shared_params[str(self.communication_round)] = G_topk.tolist()

            with open(
                os.path.join(self.log_dir, "{}_shared_params.json".format(self.rank)),
                "w",
            ) as of:
                json.dump(shared_params, of)

            logging.info("Extracting topk params")

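            # Flatten the parameters in the same order as the gradients so
            # the flat top-k indices address the same coordinates.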
            tensors_to_cat = [v.data.flatten() for v in self.model.parameters()]
            T = torch.cat(tensors_to_cat, dim=0)
            T_topk = T[G_topk]

            logging.info("Generating dictionary to send")

            m = dict()

            if not self.dict_ordered:
                raise NotImplementedError

            m["indices"] = G_topk.numpy().tolist()
            m["params"] = T_topk.numpy().tolist()

            assert len(m["indices"]) == len(m["params"])
            logging.info("Elements sending: {}".format(len(m["indices"])))

            logging.info("Generated dictionary to send")

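            # Encode each field as its own JSON string; the receiving side
            # decodes them individually in deserialized_model.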
            for key in m:
                m[key] = json.dumps(m[key])

            logging.info("Converted dictionary to json")

            # Advance the round counter so the next call appends to the
            # shared-params log instead of overwriting round 0.
            self.communication_round += 1

            return m

    def deserialized_model(self, m):
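        """
        Inverse of serialized_model(): decode the received indices and
        values, patch them into a flattened copy of the local state_dict,
        and reshape the result back into a full state_dict.
        """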
        with torch.no_grad():
            state_dict = self.model.state_dict()

            if not self.dict_ordered:
                raise NotImplementedError

            shapes = []
            lens = []
            tensors_to_cat = []
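            # Record every tensor's shape and flattened length so the
            # patched flat vector can be split and reshaped afterwards.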
            for _, v in state_dict.items():
                shapes.append(v.shape)
                t = v.flatten()
                lens.append(t.shape[0])
                tensors_to_cat.append(t)

            T = torch.cat(tensors_to_cat, dim=0)
            index_tensor = torch.tensor(json.loads(m["indices"]))
            logging.debug("Original tensor: {}".format(T[index_tensor]))
            T[index_tensor] = torch.tensor(json.loads(m["params"]))
            logging.debug("Final tensor: {}".format(T[index_tensor]))
            start_index = 0
            for i, key in enumerate(state_dict):
                end_index = start_index + lens[i]
                state_dict[key] = T[start_index:end_index].reshape(shapes[i])
                start_index = end_index

            return state_dict