import json
import logging
import os
from pathlib import Path
from time import time

import numpy as np
import torch
import torch.fft as fft

from decentralizepy.sharing.PartialModel import PartialModel


def change_transformer_fft(x):
    """
    Transforms the model changes into frequency domain

    Parameters
    ----------
    x : torch.Tensor
        Model change in the space domain

    Returns
    -------
    x : torch.Tensor
        Representation of the change int the frequency domain
    """
    return fft.rfft(x)


class FFT(PartialModel):
    """
    This class implements the fft version of model sharing
    It is based on PartialModel.py

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
        save_shared=False,
        metadata_cap=1.0,
        change_based_selection=True,
        save_accumulated="",
        accumulation=True,
        accumulate_averaging_changes=False,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph reprensenting neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine)
        alpha : float
            Percentage of model to share
        dict_ordered : bool
            Specifies if the python dict maintains the order of insertion
        save_shared : bool
            Specifies if the indices of shared parameters should be logged
        metadata_cap : float
            Share full model when self.alpha > metadata_cap
        change_based_selection : bool
            use frequency change to select topk frequencies
        save_accumulated : bool
            True if accumulated weight change in the frequency domain should be written to file. In case of accumulation
            the accumulated change is stored.
        accumulation : bool
            True if the the indices to share should be selected based on accumulated frequency change
        accumulate_averaging_changes: bool
            True if the accumulation should account the model change due to averaging

        """
        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            alpha,
            dict_ordered,
            save_shared,
            metadata_cap,
            accumulation,
            save_accumulated,
            change_transformer_fft,
            accumulate_averaging_changes,
        )
        self.change_based_selection = change_based_selection

    def apply_fft(self):
        """
        Does fft transformation of the model parameters and selects topK (alpha) of them in the frequency domain
        based on the undergone change during the current training step

        Returns
        -------
        tuple
            (a,b). a: selected fft frequencies (complex numbers), b: Their indices.

        """

        logging.info("Returning fft compressed model weights")
        with torch.no_grad():

            flat_fft = self.pre_share_model_transformed
            if self.change_based_selection:
                diff = self.model.model_change
                _, index = torch.topk(
                    diff.abs(), round(self.alpha * len(diff)), dim=0, sorted=False
                )
            else:
                _, index = torch.topk(
                    flat_fft.abs(),
                    round(self.alpha * len(flat_fft)),
                    dim=0,
                    sorted=False,
                )

        return flat_fft[index], index

    def serialized_model(self):
        """
        Convert model to json dict. self.alpha specifies the fraction of model to send.

        Returns
        -------
        dict
            Model converted to json dict

        """
        m = dict()
        if self.alpha >= self.metadata_cap:  # Share fully
            data = self.pre_share_model_transformed
            m["params"] = data.numpy()
            self.total_data += len(self.communication.encrypt(m["params"]))
            if self.model.accumulated_changes is not None:
                self.model.accumulated_changes = torch.zeros_like(
                    self.model.accumulated_changes
                )
            return m

        with torch.no_grad():
            topk, indices = self.apply_fft()
            self.model.shared_parameters_counter[indices] += 1
            self.model.rewind_accumulation(indices)

            if self.save_shared:
                shared_params = dict()
                shared_params["order"] = list(self.model.state_dict().keys())
                shapes = dict()
                for k, v in self.model.state_dict().items():
                    shapes[k] = list(v.shape)
                shared_params["shapes"] = shapes

                shared_params[self.communication_round] = indices.tolist()  # is slow

                shared_params["alpha"] = self.alpha

                with open(
                    os.path.join(
                        self.folder_path,
                        "{}_shared_params.json".format(self.communication_round + 1),
                    ),
                    "w",
                ) as of:
                    json.dump(shared_params, of)

            if not self.dict_ordered:
                raise NotImplementedError

            m["alpha"] = self.alpha
            m["params"] = topk.numpy()
            m["indices"] = indices.numpy().astype(np.int32)
            m["send_partial"] = True

            self.total_data += len(self.communication.encrypt(m["params"]))
            self.total_meta += len(self.communication.encrypt(m["indices"])) + len(
                self.communication.encrypt(m["alpha"])
            )

        return m

    def deserialized_model(self, m):
        """
        Convert received json dict to state_dict.

        Parameters
        ----------
        m : dict
            json dict received

        Returns
        -------
        state_dict
            state_dict of received

        """
        ret = dict()
        if "send_partial" not in m:
            params = m["params"]
            params_tensor = torch.tensor(params)
            ret["params"] = params_tensor

        with torch.no_grad():
            if not self.dict_ordered:
                raise NotImplementedError

            indices = m["indices"]
            alpha = m["alpha"]
            params = m["params"]

            params_tensor = torch.tensor(params)
            indices_tensor = torch.tensor(indices, dtype=torch.long)
            ret["indices"] = indices_tensor
            ret["params"] = params_tensor
            ret["send_partial"] = True
        return ret

    def _averaging(self):
        """
        Averages the received model with the local model

        """
        with torch.no_grad():
            total = None
            weight_total = 0
            tensors_to_cat = [
                v.data.flatten() for _, v in self.model.state_dict().items()
            ]
            pre_share_model = torch.cat(tensors_to_cat, dim=0)
            flat_fft = self.change_transformer(pre_share_model)

            for i, n in enumerate(self.peer_deques):
                degree, iteration, data = self.peer_deques[n].popleft()
                logging.debug(
                    "Averaging model from neighbor {} of iteration {}".format(
                        n, iteration
                    )
                )
                data = self.deserialized_model(data)
                params = data["params"]
                if "indices" in data:
                    indices = data["indices"]
                    # use local data to complement
                    topkf = flat_fft.clone().detach()
                    topkf[indices] = params
                else:
                    topkf = params

                weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
                weight_total += weight
                if total is None:
                    total = weight * topkf
                else:
                    total += weight * topkf

            # Metro-Hastings
            total += (1 - weight_total) * flat_fft
            reverse_total = fft.irfft(total)

            start_index = 0
            std_dict = {}
            for i, key in enumerate(self.model.state_dict()):
                end_index = start_index + self.lens[i]
                std_dict[key] = reverse_total[start_index:end_index].reshape(
                    self.shapes[i]
                )
                start_index = end_index

        self.model.load_state_dict(std_dict)