import json
import logging
import os

import numpy as np
import pywt
import torch

from decentralizepy.sharing.LowerBoundTopK import LowerBoundTopK


def change_transformer_wavelet(x, wavelet, level):
    """
    Transforms the model change into the wavelet frequency domain

    Parameters
    ----------
    x : torch.Tensor
        Model change in the space domain
    wavelet : str
        name of the wavelet to be used in gradient compression
    level : int
        decomposition level of the wavelet transform

    Returns
    -------
    x : torch.Tensor
        Representation of the change in the wavelet domain
    """
    coeff = pywt.wavedec(x, wavelet, level=level)
    data, coeff_slices = pywt.coeffs_to_array(coeff)
    return torch.from_numpy(data.ravel())
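
# Illustrative sketch (not part of this module's API): the transform above can be
# inverted with pywt.array_to_coeffs followed by pywt.waverec, assuming `shape` and
# `coeff_slices` come from a matching pywt.coeffs_to_array call on the same input:
#
#     arr = change_transformer_wavelet(x, "haar", 4).numpy().reshape(shape)
#     coeffs = pywt.array_to_coeffs(arr, coeff_slices, output_format="wavedec")
#     x_reconstructed = torch.from_numpy(pywt.waverec(coeffs, wavelet="haar"))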


class WaveletBound(LowerBoundTopK):
    """
    This class implements the wavelet version of model sharing with a lower bound on
    sharing frequency.
    It is based on PartialModel.py and LowerBoundTopK.py

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
        save_shared=False,
        metadata_cap=1.0,
        wavelet="haar",
        level=4,
        change_based_selection=True,
        save_accumulated="",
        accumulation=False,
        accumulate_averaging_changes=False,
        lower_bound=0.1,
        metro_hastings=True,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph representing neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine)
        alpha : float
            Percentage of model to share
        dict_ordered : bool
            Specifies if the python dict maintains the order of insertion
        save_shared : bool
            Specifies if the indices of shared parameters should be logged
        metadata_cap : float
            Share full model when self.alpha >= metadata_cap
        wavelet : str
            name of the wavelet to be used in gradient compression
        level : int
            decomposition level of the wavelet transform
        change_based_selection : bool
            True to select the top-k wavelet coefficients by their change, False to select them by magnitude
        save_accumulated : bool
            True if accumulated weight change in the wavelet domain should be written to file. In case of accumulation
            the accumulated change is stored.
        accumulation : bool
            True if the indices to share should be selected based on accumulated frequency change
        accumulate_averaging_changes : bool
            True if the accumulation should account for the model change due to averaging
        lower_bound : float
            Minimum fraction of communication rounds in which each wavelet coefficient
            should have been shared; under-shared coefficients may be force-shared
        metro_hastings : bool
            True to average with Metro-Hastings weights, False to use a per-coefficient
            weighted average of the received models
        """
        self.wavelet = wavelet
        self.level = level

        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            lower_bound,
            metro_hastings,
            alpha=alpha,
            dict_ordered=dict_ordered,
            save_shared=save_shared,
            metadata_cap=metadata_cap,
            accumulation=accumulation,
            save_accumulated=save_accumulated,
            change_transformer=lambda x: change_transformer_wavelet(x, wavelet, level),
            accumulate_averaging_changes=accumulate_averaging_changes,
        )

        self.change_based_selection = change_based_selection

        # Do a dummy transform to get the shape and coefficient slices
        coeff = pywt.wavedec(self.init_model.numpy(), self.wavelet, level=self.level)
        data, coeff_slices = pywt.coeffs_to_array(coeff)
        self.wt_shape = data.shape
        self.coeff_slices = coeff_slices
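        # wt_shape and coeff_slices are reused in _averaging() to reshape the flat
        # wavelet vector and invert the transform via pywt.array_to_coeffs / waverec.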

    def apply_wavelet(self):
        """
        Transforms the model parameters into the wavelet domain and selects the top
        alpha fraction of coefficients, by default based on their change during the
        current training step

        Returns
        -------
        tuple
            (a,b). a: selected wavelet coefficients, b: Their indices.

        """

        logging.info("Returning wavelet compressed model weights")
        data = self.pre_share_model_transformed
        if self.change_based_selection:
            diff = self.model.model_change
            _, index = torch.topk(
                diff.abs(),
                round(self.alpha * len(diff)),
                dim=0,
                sorted=False,
            )
        else:
            _, index = torch.topk(
                data.abs(),
                round(self.alpha * len(data)),
                dim=0,
                sorted=False,
            )
        index, _ = torch.sort(index)
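        # Selection sketch (illustrative values): with alpha=0.5 and
        # diff.abs() == tensor([0.1, 3.0, 0.2, 2.0]), topk keeps indices 1 and 3,
        # so this returns (data[[1, 3]], tensor([1, 3])).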
        return data[index], index

    def extract_top_gradients(self):
        """
        Extract the indices and values of the topK gradients.
        The gradients must have been accumulated.

        Returns
        -------
        tuple
            (a,b). a: values associated with the selected coefficients (unused downstream), b: their sorted indices.

        """
        if self.lower_bound == 0.0:
            return self.apply_wavelet()

        data = self.pre_share_model_transformed
        if self.change_based_selection:
            diff = self.model.model_change
            _, index = torch.topk(
                diff.abs(),
                round(self.alpha * len(diff)),
                dim=0,
                sorted=False,
            )
        else:
            _, index = torch.topk(
                data.abs(),
                round(self.alpha * len(data)),
                dim=0,
                sorted=False,
            )
        ind, _ = torch.sort(index)

        if self.communication_round > self.start_lower_bounding_at:
            # Work on a local copy of the counter: the superclass increments the real
            # one at a point that is inconvenient for this subclass.
            currently_shared = self.model.shared_parameters_counter.clone().detach()
            currently_shared[ind] += 1
            ind_small = (
                currently_shared < self.communication_round * self.lower_bound
            ).nonzero(as_tuple=True)[0]
            ind_small_unique = np.setdiff1d(
                ind_small.numpy(), ind.numpy(), assume_unique=True
            )
            take_max = round(self.lower_bound * self.alpha * data.shape[0])
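            # Sizing sketch (illustrative numbers): with data.shape[0] == 1_000_000,
            # alpha == 0.1 and lower_bound == 0.1, take_max == round(0.1 * 0.1 * 1e6)
            # == 10_000, so at most 10_000 under-shared indices are appended to the
            # ~100_000 selected by topk above.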
            logging.info(
                "lower: %i %i %i", len(ind_small), len(ind_small_unique), take_max
            )
            if take_max > ind_small_unique.shape[0]:
                take_max = ind_small_unique.shape[0]
            to_take = torch.rand(ind_small_unique.shape[0])
            _, ind_of_to_take = torch.topk(to_take, take_max, dim=0, sorted=False)
            ind_bound = torch.from_numpy(ind_small_unique)[ind_of_to_take]
            logging.info("lower bounding: %i %i", len(ind), len(ind_bound))
            # val = torch.concat(val, G_topk[ind_bound]) # not really needed, as these are abs values and not further used
            ind = torch.cat([ind, ind_bound])

        index, _ = torch.sort(ind)
        return _, index

    def serialized_model(self):
        """
        Convert the model to a dict for sending. self.alpha specifies the fraction of the model to share.

        Returns
        -------
        dict
            Dict containing the (partial) model

        """
        m = dict()
        if self.alpha >= self.metadata_cap:  # Share fully
            data = self.pre_share_model_transformed
            m["params"] = data.numpy()
            if self.model.accumulated_changes is not None:
                self.model.accumulated_changes = torch.zeros_like(
                    self.model.accumulated_changes
                )
            return m

        with torch.no_grad():
            topk, indices = self.apply_wavelet()
            self.model.shared_parameters_counter[indices] += 1
            self.model.rewind_accumulation(indices)
            if self.save_shared:
                shared_params = dict()
                shared_params["order"] = list(self.model.state_dict().keys())
                shapes = dict()
                for k, v in self.model.state_dict().items():
                    shapes[k] = list(v.shape)
                shared_params["shapes"] = shapes

                shared_params[self.communication_round] = indices.tolist()  # is slow

                shared_params["alpha"] = self.alpha

                with open(
                    os.path.join(
                        self.folder_path,
                        "{}_shared_params.json".format(self.communication_round + 1),
                    ),
                    "w",
                ) as of:
                    json.dump(shared_params, of)

            if not self.dict_ordered:
                raise NotImplementedError

            m["alpha"] = self.alpha

            m["params"] = topk.numpy()

            m["indices"] = indices.numpy().astype(np.int32)

            m["send_partial"] = True

            return m

    def deserialized_model_avg(self, m):
        """
        Convert a received dict into a dense wavelet-domain tensor for averaging.

        Parameters
        ----------
        m : dict
            dict received

        Returns
        -------
        tuple
            (T, index_tensor): tensor with the received coefficients scattered at
            their indices, plus the index tensor; fully shared models are handled by
            the superclass deserialization

        """
        if "send_partial" not in m:
            return super().deserialized_model(m)

        with torch.no_grad():
            state_dict = self.model.state_dict()

            if not self.dict_ordered:
                raise NotImplementedError

            # could be made more efficient
            T = torch.zeros_like(self.init_model)
            index_tensor = torch.tensor(m["indices"], dtype=torch.long)
            logging.debug("Original tensor: {}".format(T[index_tensor]))
            T[index_tensor] = torch.tensor(m["params"])
            logging.debug("Final tensor: {}".format(T[index_tensor]))

            return T, index_tensor

    def deserialized_model(self, m):
        """
        Convert a received dict into a dict of tensors.

        Parameters
        ----------
        m : dict
            received dict

        Returns
        -------
        dict
            Dict with a "params" tensor and, for partial models, an "indices" tensor

        """
        ret = dict()
        if "send_partial" not in m:
            params = m["params"]
            params_tensor = torch.tensor(params)
            ret["params"] = params_tensor
            return ret

        with torch.no_grad():
            if not self.dict_ordered:
                raise NotImplementedError
            alpha = m["alpha"]

            params_tensor = torch.tensor(m["params"])
            indices_tensor = torch.tensor(m["indices"], dtype=torch.long)
            ret["indices"] = indices_tensor
            ret["params"] = params_tensor
            ret["send_partial"] = True
        return ret

    def _averaging(self):
        """
        Averages the received models with the local model in the wavelet domain

        """
        with torch.no_grad():
            total = None
            weight_total = 0
            wt_params = self.pre_share_model_transformed
            if not self.metro_hastings:
                weight_vector = torch.ones_like(wt_params)
                datas = []
            for n in self.peer_deques:
                degree, iteration, data = self.peer_deques[n].popleft()
                logging.debug(
                    "Averaging model from neighbor {} of iteration {}".format(
                        n, iteration
                    )
                )
                data = self.deserialized_model(data)
                params = data["params"]
                if "indices" in data:
                    indices = data["indices"]
                    if not self.metro_hastings:
                        weight_vector[indices] += 1
                        topkwf = torch.zeros_like(wt_params)
                        topkwf[indices] = params
                        topkwf = topkwf.reshape(self.wt_shape)
                        datas.append(topkwf)
                    else:
                        # use local data to complement
                        topkwf = wt_params.clone().detach()
                        topkwf[indices] = params
                        topkwf = topkwf.reshape(self.wt_shape)

                else:
                    topkwf = params.reshape(self.wt_shape)
                    if not self.metro_hastings:
                        weight_vector += 1
                        datas.append(topkwf)

                if self.metro_hastings:
                    weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
                    weight_total += weight
                    if total is None:
                        total = weight * topkwf
                    else:
                        total += weight * topkwf
            if not self.metro_hastings:
                weight_vector = 1.0 / weight_vector
                # speed up by exploiting sparsity
                total = wt_params * weight_vector
                for d in datas:
                    total += d * weight_vector
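                # e.g. (illustrative): if exactly two neighbors sent coefficient j,
                # weight_vector[j] == 1/3 here, so total[j] is the plain mean
                # (wt_params[j] + received_1[j] + received_2[j]) / 3.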
            else:
                # Metro-Hastings
                total += (1 - weight_total) * wt_params
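                # Metro-Hastings: each neighbor contributed with weight
                # 1 / (max(#neighbors, degree) + 1); the local wavelet-domain model
                # keeps the remaining 1 - weight_total mass.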

            avg_wf_params = pywt.array_to_coeffs(
                total.numpy(), self.coeff_slices, output_format="wavedec"
            )
            reverse_total = torch.from_numpy(
                pywt.waverec(avg_wf_params, wavelet=self.wavelet)
            )
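            # Unflatten the reconstructed space-domain vector back into per-tensor
            # chunks (lens and shapes presumably cached by the base class) to rebuild
            # the state_dict.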

            start_index = 0
            std_dict = {}
            for i, key in enumerate(self.model.state_dict()):
                end_index = start_index + self.lens[i]
                std_dict[key] = reverse_total[start_index:end_index].reshape(
                    self.shapes[i]
                )
                start_index = end_index

        self.model.load_state_dict(std_dict)