import base64
import json
import logging
import os
import pickle
from pathlib import Path
from time import time

import torch
import torch.fft as fft

from decentralizepy.sharing.Sharing import Sharing


class FFT(Sharing):
    """
    This class implements the FFT (frequency-domain) version of model sharing.
    It is based on PartialModel.py.

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        alpha=1.0,
        dict_ordered=True,
        save_shared=False,
        metadata_cap=1.0,
        pickle=True,
        change_based_selection=True,
        accumulation=True,
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph representing neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine)
        alpha : float
            Percentage of model to share
        dict_ordered : bool
            Specifies if the python dict maintains the order of insertion
        save_shared : bool
            Specifies if the indices of shared parameters should be logged.
            Only honoured for local ranks 0 and 1; disabled for every other rank
        metadata_cap : float
            Share full model when self.alpha > metadata_cap
        pickle : bool
            use pickle to serialize the model parameters
        change_based_selection : bool
            use frequency change to select topk frequencies
        accumulation : bool
            True if the indices to share should be selected based on accumulated frequency change
        """
        super().__init__(
            rank, machine_id, communication, mapping, graph, model, dataset, log_dir
        )
        self.alpha = alpha
        self.dict_ordered = dict_ordered
        self.save_shared = save_shared
        self.metadata_cap = metadata_cap
        self.total_meta = 0

        # NOTE: the `pickle` parameter shadows the `pickle` module inside this
        # method; the name is kept because it is part of the public signature.
        self.pickle = pickle

        logging.info("subsampling pickling=" + str(pickle))

        # Only save for 2 procs (local ranks 0 and 1): Save space.
        # The previous test `rank != 0 or rank != 1` was true for every rank,
        # which silently disabled saving everywhere.
        if self.save_shared and rank not in (0, 1):
            self.save_shared = False

        if self.save_shared:
            # self.log_dir and self.rank are not assigned in this method;
            # presumably the base-class constructor sets them.
            self.folder_path = os.path.join(
                self.log_dir, "shared_params/{}".format(self.rank)
            )
            Path(self.folder_path).mkdir(parents=True, exist_ok=True)

        self.change_based_selection = change_based_selection
        self.accumulation = accumulation

    def apply_fft(self):
        """
        Transform the flattened model parameters with a real FFT and pick the
        top-K (a fraction alpha of all frequencies) entries.

        The ranking score is the magnitude of:
        - accumulated_frequency + (current fft - model.prev) when both
          change_based_selection and accumulation are enabled,
        - current fft - accumulated_frequency when only change_based_selection
          is enabled,
        - the current fft itself otherwise.

        When accumulation is enabled, the entries of
        self.model.accumulated_frequency at the selected indices are reset to 0.

        Returns
        -------
        tuple
            (a,b). a: selected fft frequencies (complex numbers), b: Their indices.

        """
        logging.info("Returning fft compressed model weights")

        # One long 1-D vector holding every entry of the state dict, in order
        flat_params = torch.cat(
            [tensor.data.flatten() for tensor in self.model.state_dict().values()],
            dim=0,
        )
        flat_fft = fft.rfft(flat_params)
        num_selected = round(self.alpha * len(flat_fft))

        # Decide which quantity the frequencies are ranked by
        if not self.change_based_selection:
            score = flat_fft
        elif self.accumulation:
            logging.info(
                "fft topk extract frequencies based on accumulated model frequency change"
            )
            score = self.model.accumulated_frequency + (flat_fft - self.model.prev)
        else:
            score = flat_fft - self.model.accumulated_frequency

        _, index = torch.topk(score.abs(), num_selected, dim=0, sorted=False)

        if self.accumulation:
            # Frequencies that are about to be shared start accumulating afresh
            self.model.accumulated_frequency[index] = 0.0
        return flat_fft[index], index

    def serialized_model(self):
        """
        Build the message to send: the selected fft frequencies, their indices
        and alpha. self.alpha specifies the fraction of the model to send.

        When self.alpha exceeds self.metadata_cap the parent implementation is
        used instead. When self.save_shared is set, the shared indices are also
        dumped to a json file in self.folder_path.

        Returns
        -------
        dict
            Dict with keys "alpha", "params" (numpy array of frequencies) and
            "indices" (numpy array of their positions)

        Raises
        ------
        NotImplementedError
            If self.dict_ordered is False

        """
        # Above the cap the complete model is shared through the parent class
        if self.alpha > self.metadata_cap:
            return super().serialized_model()

        with torch.no_grad():
            topk, indices = self.apply_fft()

            if self.save_shared:
                state = self.model.state_dict()
                shared_params = {
                    "order": list(state.keys()),
                    "shapes": {
                        name: list(tensor.shape) for name, tensor in state.items()
                    },
                    self.communication_round: indices.tolist(),  # is slow
                    "alpha": self.alpha,
                }

                file_name = "{}_shared_params.json".format(
                    self.communication_round + 1
                )
                with open(os.path.join(self.folder_path, file_name), "w") as of:
                    json.dump(shared_params, of)

            if not self.dict_ordered:
                raise NotImplementedError

            m = {
                "alpha": self.alpha,
                "params": topk.numpy(),
                "indices": indices.numpy(),
            }

            # Book-keeping of payload vs. metadata size, measured on the
            # output of communication.encrypt
            encrypt = self.communication.encrypt
            self.total_data += len(encrypt(m["params"]))
            self.total_meta += len(encrypt(m["indices"])) + len(encrypt(m["alpha"]))

            return m

    def deserialized_model(self, m):
        """
        Convert a received message into tensors of shared frequencies.

        When self.alpha exceeds self.metadata_cap the parent implementation is
        used instead. Otherwise only the "params" and "indices" entries of the
        message are read; the local model is not touched.

        Parameters
        ----------
        m : dict
            Received message with at least the keys "params" and "indices"

        Returns
        -------
        dict
            Dict with keys "indices" and "params", both converted to torch
            tensors

        Raises
        ------
        NotImplementedError
            If self.dict_ordered is False

        """
        if self.alpha > self.metadata_cap:  # Share fully
            return super().deserialized_model(m)

        if not self.dict_ordered:
            raise NotImplementedError

        # The earlier implementation also flattened and concatenated the whole
        # local model here and then discarded the result; that dead work has
        # been removed.
        with torch.no_grad():
            params_tensor = torch.tensor(m["params"])
            indices_tensor = torch.tensor(m["indices"])
            return {"indices": indices_tensor, "params": params_tensor}

    def step(self):
        """
        Perform a sharing step. Implements D-PSGD.

        The serialized model is sent to the neighbors, one message per peer is
        collected, and the models are averaged in the frequency domain with
        Metropolis-Hastings weights. Frequencies a peer did not share are
        complemented with the local ones. The averaged spectrum is transformed
        back and loaded into self.model, and self.communication_round is
        incremented.

        """
        t_start = time()
        data = self.serialized_model()
        t_post_serialize = time()
        my_uid = self.mapping.get_uid(self.rank, self.machine_id)
        all_neighbors = self.graph.neighbors(my_uid)
        iter_neighbors = self.get_neighbors(all_neighbors)
        data["degree"] = len(all_neighbors)
        data["iteration"] = self.communication_round
        for neighbor in iter_neighbors:
            self.communication.send(neighbor, data)
        t_post_send = time()
        logging.info("Waiting for messages from neighbors")
        while not self.received_from_all():
            sender, received = self.communication.receive()
            logging.debug("Received model from {}".format(sender))
            # Strip the envelope fields so only the model payload is queued
            degree = received["degree"]
            iteration = received["iteration"]
            del received["degree"]
            del received["iteration"]
            self.peer_deques[sender].append((degree, iteration, received))
            logging.info(
                "Deserialized received model from {} of iteration {}".format(
                    sender, iteration
                )
            )
        t_post_recv = time()

        logging.info("Starting model averaging after receiving from all neighbors")

        # FFT of this model
        state_dict = self.model.state_dict()
        shapes = []
        lens = []
        tensors_to_cat = []
        for v in state_dict.values():
            shapes.append(v.shape)
            t = v.flatten()
            lens.append(t.shape[0])
            tensors_to_cat.append(t)
        concated = torch.cat(tensors_to_cat, dim=0)
        flat_fft = fft.rfft(concated)

        # Start from zeros so the accumulation also works without any peer
        total = torch.zeros_like(flat_fft)
        weight_total = 0

        for n in self.peer_deques:
            degree, iteration, payload = self.peer_deques[n].popleft()
            logging.debug(
                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
            )
            payload = self.deserialized_model(payload)
            params = payload["params"]
            indices = payload["indices"]
            # use local data to complement
            topkf = flat_fft.clone().detach()
            topkf[indices] = params

            weight = 1 / (max(len(self.peer_deques), degree) + 1)  # Metro-Hastings
            weight_total += weight
            total += weight * topkf

        # Metro-Hastings
        total += (1 - weight_total) * flat_fft
        # irfft defaults to an even output length 2 * (len(total) - 1); pass
        # the original length so models with an odd number of parameters are
        # reconstructed completely.
        reverse_total = fft.irfft(total, n=concated.shape[0])

        # Cut the flat vector back into tensors of the original shapes
        start_index = 0
        std_dict = {}
        for key, shape, length in zip(state_dict, shapes, lens):
            end_index = start_index + length
            std_dict[key] = reverse_total[start_index:end_index].reshape(shape)
            start_index = end_index

        self.model.load_state_dict(std_dict)

        logging.info("Model averaging complete")

        self.communication_round += 1

        t_end = time()

        logging.info(
            "Sharing::step | Serialize: %f; Send: %f; Recv: %f; Averaging: %f; Total: %f",
            t_post_serialize - t_start,
            t_post_send - t_post_serialize,
            t_post_recv - t_post_send,
            t_end - t_post_recv,
            t_end - t_start,
        )