import json
import logging
import os
from pathlib import Path

import numpy as np
import pywt
import torch
from decentralizepy.sharing.SharingDynamicGraph import SharingDynamicGraph
from decentralizepy.utils import conditional_value, identity

def change_transformer_wavelet(x, wavelet, level):
    """
    Transforms the model changes into the wavelet frequency domain

    Parameters
    ----------
    x : torch.Tensor
        Model change in the space domain (the callers in this module pass a
        flattened 1-D tensor)
    wavelet : str
        name of the wavelet to be used in gradient compression
    level : int
        Decomposition level passed to ``pywt.wavedec``

    Returns
    -------
    torch.Tensor
        Representation of the change in the wavelet domain: all coefficient
        arrays concatenated as laid out by ``pywt.coeffs_to_array``, flattened
    """
    coeff = pywt.wavedec(x, wavelet, level=level)
    # Only the flat coefficient array is returned; the coefficient slices
    # (needed to invert the transform) are discarded here.
    data, _ = pywt.coeffs_to_array(coeff)
    return torch.from_numpy(data.ravel())

class JwinsDynamicGraph(SharingDynamicGraph):
    """
    Partial model sharing in the wavelet domain over a dynamic graph.

    The flattened model is transformed with ``pywt.wavedec``. Only a fraction
    ``alpha`` of the wavelet coefficients is sent to neighbors: the top
    coefficients by change, or by magnitude. Optionally, some under-shared
    coefficients are added through lower bounding. Received coefficients are
    averaged in the wavelet domain and transformed back before being loaded
    into the local model.

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
        save_shared=False,
        metadata_cap=1.0,
        wavelet="haar",
        level=4,
        change_based_selection=True,
        save_accumulated="",
        accumulation=False,
        accumulate_averaging_changes=False,
        rw_chance=1,
        rw_length=6,
        neighbor_bound=(3, 5),
        change_topo_interval=10,
        avg = False,
        lower_bound=0.1,
        metro_hastings=True,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph reprensenting neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only written for ranks 0 and 1)
        alpha : float
            Fraction of wavelet coefficients to share
        dict_ordered : bool
            Specifies if the python dict maintains the order of insertion
        save_shared : bool
            Specifies if the indices of shared parameters should be logged
        metadata_cap : float
            Share the full model when self.alpha >= metadata_cap
        wavelet : str
            name of the wavelet to be used in gradient compression
        level : int
            Wavelet decomposition level passed to pywt.wavedec
        change_based_selection : bool
            If True, select the top-k coefficients by their change; otherwise by
            their absolute value
        save_accumulated : str
            Anything other than "" enables writing the per-round model change (in
            the wavelet domain; the accumulated change when accumulation is on)
            to log_dir/model_change/<rank>
        accumulation : bool
            True if the indices to share should be selected based on accumulated
            frequency change
        accumulate_averaging_changes : bool
            True if the accumulation should account the model change due to averaging
        rw_chance, rw_length, neighbor_bound, change_topo_interval, avg
            Forwarded unchanged to SharingDynamicGraph (see that class for their
            meaning)
        lower_bound : float
            Once communication_round > 1 / lower_bound, coefficients shared fewer
            than communication_round * lower_bound times become candidates for
            additional sharing. A value <= 0 disables lower bounding.
        metro_hastings : bool
            If True, average with Metropolis-Hastings weights; otherwise average
            each coefficient uniformly over the local value and the values received
            for it
        """
        self.wavelet = wavelet
        self.level = level

        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            rw_chance,
            rw_length,
            neighbor_bound,
            change_topo_interval,
            avg,
        )
        self.alpha = alpha
        self.dict_ordered = dict_ordered
        self.save_shared = save_shared
        self.metadata_cap = metadata_cap
        self.accumulation = accumulation
        self.save_accumulated = conditional_value(save_accumulated, "", False)
        self.change_transformer = lambda x: change_transformer_wavelet(x, wavelet, level)
        self.accumulate_averaging_changes = accumulate_averaging_changes

        # Flatten the initial model into one vector, remembering every entry's
        # shape and length so averaged vectors can be split back in _averaging.
        self.shapes = []
        self.lens = []
        with torch.no_grad():
            tensors_to_cat = []
            for v in self.model.state_dict().values():
                self.shapes.append(v.shape)
                t = v.flatten()
                self.lens.append(t.shape[0])
                tensors_to_cat.append(t)
            self.init_model = torch.cat(tensors_to_cat, dim=0)
            # A single transform is enough to size every wavelet-domain buffer.
            init_transformed = self.change_transformer(self.init_model)
            if self.accumulation:
                self.model.accumulated_changes = torch.zeros_like(init_transformed)
                self.prev = self.init_model
        self.number_of_params = self.init_model.shape[0]
        if self.save_accumulated:
            self.model_change_path = os.path.join(
                self.log_dir, "model_change/{}".format(self.rank)
            )
            Path(self.model_change_path).mkdir(parents=True, exist_ok=True)

            self.model_val_path = os.path.join(
                self.log_dir, "model_val/{}".format(self.rank)
            )
            Path(self.model_val_path).mkdir(parents=True, exist_ok=True)

        # Only save for 2 procs: Save space
        if self.save_shared and rank not in (0, 1):
            self.save_shared = False

        if self.save_shared:
            self.folder_path = os.path.join(
                self.log_dir, "shared_params/{}".format(self.rank)
            )
            Path(self.folder_path).mkdir(parents=True, exist_ok=True)

        self.model.shared_parameters_counter = torch.zeros(
            init_transformed.shape[0], dtype=torch.int32
        )

        self.change_based_selection = change_based_selection

        # Do a dummy transform to get the shape and coefficient slices needed to
        # invert the transform in _averaging.
        coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
        data, coeff_slices = pywt.coeffs_to_array(coeff)
        self.wt_shape = data.shape
        self.coeff_slices = coeff_slices

        self.lower_bound = lower_bound
        self.metro_hastings = metro_hastings
        # Lower bounding starts once the required share count
        # (communication_round * lower_bound) exceeds 1. Always define the
        # attribute so a non-positive lower_bound can never hit a missing attribute.
        if self.lower_bound > 0:
            self.start_lower_bounding_at = 1 / self.lower_bound
        else:
            self.start_lower_bounding_at = float("inf")

    def apply_wavelet(self):
        """
        Selects the topK (alpha) wavelet coefficients of the pre-share model. They
        are ranked by the change undergone during the current training step, or by
        magnitude when change_based_selection is False. Under-shared coefficients
        may be added when lower bounding is active.

        Returns
        -------
        tuple
            (a,b). a: selected wavelet coefficients, b: Their indices (sorted).

        """
        data = self.pre_share_model_transformed
        scores = self.model.model_change if self.change_based_selection else data
        _, index = torch.topk(
            scores.abs(),
            round(self.alpha * len(scores)),
            dim=0,
            sorted=False,
        )
        ind, _ = torch.sort(index)

        if self.lower_bound <= 0.0:
            return data[ind].clone().detach(), ind

        if self.communication_round > self.start_lower_bounding_at:
            # shared_parameters_counter is only incremented in serialized_model
            # after selection, so account for this round's picks on a copy.
            currently_shared = self.model.shared_parameters_counter.clone().detach()
            currently_shared[ind] += 1
            ind_small = (
                currently_shared < self.communication_round * self.lower_bound
            ).nonzero(as_tuple=True)[0]
            # Under-shared coefficients that were not already selected above.
            ind_small_unique = np.setdiff1d(
                ind_small.numpy(), ind.numpy(), assume_unique=True
            )
            take_max = round(self.lower_bound * self.alpha * data.shape[0])
            logging.info(
                "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max
            )
            take_max = min(take_max, ind_small_unique.shape[0])
            # Top-k over i.i.d. uniform noise picks a random subset of candidates.
            to_take = torch.rand(ind_small_unique.shape[0])
            _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False)
            ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take]
            logging.info("lower bounding: %i %i", len(ind), len(ind_bound))
            ind = torch.cat([ind, ind_bound])

        index, _ = torch.sort(ind)
        return data[index].clone().detach(), index

    def serialized_model(self):
        """
        Convert model to json dict. self.alpha specifies the fraction of model to send.

        Returns
        -------
        dict
            Model converted to json dict. Contains only "params" when the full
            model is shared; otherwise "alpha", "indices", "params" and
            "send_partial".

        Raises
        ------
        NotImplementedError
            If partial sharing is requested and dict_ordered is False (raised
            before any counter, accumulation or file side effect).

        """
        m = dict()
        if self.alpha >= self.metadata_cap:  # Share fully
            m["params"] = self.pre_share_model_transformed.numpy()
            if self.model.accumulated_changes is not None:
                self.model.accumulated_changes = torch.zeros_like(
                    self.model.accumulated_changes
                )
            return m

        if not self.dict_ordered:
            raise NotImplementedError

        with torch.no_grad():
            topk, self.G_topk = self.apply_wavelet()
            self.model.shared_parameters_counter[self.G_topk] += 1
            self.model.rewind_accumulation(self.G_topk)
            if self.save_shared:
                state = self.model.state_dict()
                shared_params = dict()
                shared_params["order"] = list(state.keys())
                shared_params["shapes"] = {k: list(v.shape) for k, v in state.items()}
                shared_params[self.communication_round] = self.G_topk.tolist()

                with open(
                    os.path.join(
                        self.folder_path,
                        "{}_shared_params.json".format(self.communication_round + 1),
                    ),
                    "w",
                ) as of:
                    json.dump(shared_params, of)

            m["alpha"] = self.alpha
            m["indices"] = self.G_topk.numpy().astype(np.int32)
            m["params"] = topk.numpy()
            m["send_partial"] = True

            assert len(m["indices"]) == len(m["params"])
            logging.info("Elements sending: {}".format(len(m["indices"])))

            logging.info("Generated dictionary to send")

            logging.info("Converted dictionary to pickle")

            return m

    def deserialized_model(self, m):
        """
        Convert received dict to state_dict.

        Parameters
        ----------
        m : dict
            received dict

        Returns
        -------
        dict
            {"params": tensor} for a full model; otherwise
            {"indices": long tensor, "params": tensor, "send_partial": True}

        """
        if "send_partial" not in m:
            return {"params": torch.tensor(m["params"])}

        with torch.no_grad():
            if not self.dict_ordered:
                raise NotImplementedError

            ret = dict()
            ret["indices"] = torch.tensor(m["indices"], dtype=torch.long)
            ret["params"] = torch.tensor(m["params"])
            ret["send_partial"] = True
        return ret

    def _averaging(self):
        """
        Averages the received models with the local model in the wavelet domain.
        The result is transformed back and loaded into the local model.

        With metro_hastings, each neighbor gets weight
        1 / (max(len(self.peer_deques), degree) + 1). Coefficients a neighbor did
        not send are filled with local values, and the local model takes the
        remaining weight. Without it, each coefficient is averaged uniformly over
        the local value and the values received for it. If nothing was received,
        the local model is kept.

        """
        with torch.no_grad():
            total = None
            weight_total = 0
            wt_params = self.pre_share_model_transformed
            if not self.metro_hastings:
                # Per-coefficient number of contributors, local model included.
                weight_vector = torch.ones_like(wt_params)
                datas = []
            batch = self._preprocessing_received_models()
            for n, vals in batch.items():
                if len(vals) > 1:
                    # this should never happen
                    logging.info("GOT double message in dynamic graph!")
                # Always take the first message so degree/iteration/data belong to
                # neighbor n (never stale values from the previous neighbor).
                degree, iteration, data = vals[0]
                logging.debug(
                    "Averaging model from neighbor {} of iteration {}".format(
                        n, iteration
                    )
                )
                params = data["params"]
                if "indices" in data:
                    indices = data["indices"]
                    if not self.metro_hastings:
                        weight_vector[indices] += 1
                        topkwf = torch.zeros_like(wt_params)
                        topkwf[indices] = params
                        topkwf = topkwf.reshape(self.wt_shape)
                        datas.append(topkwf)
                    else:
                        # use local data to complement
                        topkwf = wt_params.clone().detach()
                        topkwf[indices] = params
                        topkwf = topkwf.reshape(self.wt_shape)
                else:
                    topkwf = params.reshape(self.wt_shape)
                    if not self.metro_hastings:
                        weight_vector += 1
                        datas.append(topkwf)

                if self.metro_hastings:
                    weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
                    weight_total += weight
                    if total is None:
                        total = weight * topkwf
                    else:
                        total += weight * topkwf

            if not self.metro_hastings:
                weight_vector = 1.0 / weight_vector
                # speed up by exploiting sparsity
                total = wt_params * weight_vector
                for d in datas:
                    total += d * weight_vector
            elif total is None:
                # No neighbor models received: keep the local model unchanged.
                total = wt_params.clone()
            else:
                # Metro-Hastings: the local model takes the remaining weight.
                total += (1 - weight_total) * wt_params

            avg_wf_params = pywt.array_to_coeffs(
                total.numpy(), self.coeff_slices, output_format="wavedec"
            )
            reverse_total = torch.from_numpy(
                pywt.waverec(avg_wf_params, wavelet=self.wavelet)
            )

            # Split the flat vector back into the original state_dict entries.
            std_dict = {}
            start_index = 0
            for key, length, shape in zip(
                self.model.state_dict(), self.lens, self.shapes
            ):
                end_index = start_index + length
                std_dict[key] = reverse_total[start_index:end_index].reshape(shape)
                start_index = end_index

        self.model.load_state_dict(std_dict)

    def _pre_step(self):
        """
        Called at the beginning of step.

        Snapshots the flattened model (and its wavelet transform) before sharing.
        Stores in self.model.model_change the wavelet-domain change since the last
        _post_step, optionally combined with the accumulated changes.

        """
        logging.info("PartialModel _pre_step")
        with torch.no_grad():
            tensors_to_cat = [v.data.flatten() for v in self.model.state_dict().values()]
            self.pre_share_model = torch.cat(tensors_to_cat, dim=0)
            # Would only need one of the transforms
            self.pre_share_model_transformed = self.change_transformer(
                self.pre_share_model
            )
            change = self.change_transformer(self.pre_share_model - self.init_model)
            if self.accumulation:
                if not self.accumulate_averaging_changes:
                    # Need to accumulate in _pre_step as the accumulation gets rewind during the step
                    self.model.accumulated_changes += change
                    change = self.model.accumulated_changes.clone().detach()
                else:
                    # For the legacy implementation, we will only rewind currently accumulated values
                    # and add the model change due to averaging in the end
                    change += self.model.accumulated_changes
            # stores change of the model due to training, change due to averaging is not accounted
            self.model.model_change = change

    def _post_step(self):
        """
        Called at the end of step.

        Records the post-averaging model as the reference for the next change. It
        optionally accumulates the averaging-induced change, saves the current
        model change when save_accumulated is set, and then clears it.

        """
        logging.info("PartialModel _post_step")
        with torch.no_grad():
            tensors_to_cat = [v.data.flatten() for v in self.model.state_dict().values()]
            post_share_model = torch.cat(tensors_to_cat, dim=0)
            self.init_model = post_share_model
            if self.accumulation:
                if self.accumulate_averaging_changes:
                    self.model.accumulated_changes += self.change_transformer(
                        self.init_model - self.prev
                    )
                self.prev = self.init_model
        # save_change reads self.model.model_change, so persist it before clearing.
        if self.save_accumulated:
            self.save_change()
        self.model.model_change = None

    def save_vector(self, v, s):
        """
        Saves the given vector to the file <s>/<communication_round + 1>.json,
        together with the model's state_dict key order and shapes.

        Parameters
        ----------
        v : torch.tensor
            The torch tensor to write to file
        s : str
            Path to folder to write to

        """
        state = self.model.state_dict()
        output_dict = {
            "order": list(state.keys()),
            "shapes": {k: list(t.shape) for k, t in state.items()},
            "tensor": v.tolist(),
        }

        with open(
            os.path.join(
                s,
                "{}.json".format(self.communication_round + 1),
            ),
            "w",
        ) as of:
            json.dump(output_dict, of)

    def save_change(self):
        """
        Saves the current model change (self.model.model_change) for this round
        into self.model_change_path.

        """
        self.save_vector(self.model.model_change, self.model_change_path)