import logging
import pickle
from collections import deque

import numpy
import torch


class Sharing:
    """
    API defining who to share with and what, and what to do on receiving.

    The default implementation sends the full model to every neighbor each
    round and replaces the local model with a Metropolis-Hastings weighted
    average of the local model and the models received (D-PSGD).

    """

    def __init__(
        self, rank, machine_id, communication, mapping, graph, model, dataset, log_dir
    ):
        """
        Constructor

        Parameters
        ----------
        rank : int
            Local rank
        machine_id : int
            Global machine id
        communication : decentralizepy.communication.Communication
            Communication module used to send and receive messages
        mapping : decentralizepy.mappings.Mapping
            Mapping (rank, machine_id) -> uid
        graph : decentralizepy.graphs.Graph
            Graph representing neighbors
        model : decentralizepy.models.Model
            Model to train
        dataset : decentralizepy.datasets.Dataset
            Dataset for sharing data. Not implemented yet! TODO
        log_dir : str
            Location to write shared_params (only writing for 2 procs per machine).
            Only stored by this class; nothing here writes to it.

        """
        self.rank = rank
        self.machine_id = machine_id
        self.uid = mapping.get_uid(rank, machine_id)
        self.communication = communication
        self.mapping = mapping
        self.graph = graph
        self.model = model
        self.dataset = dataset
        self.communication_round = 0
        self.log_dir = log_dir
        # Running sum of len(communication.encrypt(array)) over every array
        # produced by serialized_model(); accumulates across rounds.
        self.total_data = 0

        # One FIFO per neighbor holding (degree, iteration, params) tuples that
        # were received but not yet consumed by step(). step() pops exactly one
        # entry per neighbor per round, so extra messages wait for later rounds.
        self.peer_deques = dict()
        for n in self.graph.neighbors(self.uid):
            self.peer_deques[n] = deque()

    def received_from_all(self):
        """
        Check if all neighbors have sent the current iteration

        Returns
        -------
        bool
            True if every neighbor has at least one queued message (trivially
            True when there are no neighbors), False otherwise

        """
        return all(len(pending) > 0 for pending in self.peer_deques.values())

    def get_neighbors(self, neighbors):
        """
        Choose which neighbors to share with

        The base implementation shares with all of them; subclasses may
        override this to select a subset.

        Parameters
        ----------
        neighbors : list(int)
            List of all neighbors

        Returns
        -------
        list(int)
            Neighbors to share with

        """
        # modify neighbors here
        return neighbors

    def serialized_model(self):
        """
        Convert model to a dictionary. Here we can choose how much to share

        Also adds the encrypted size of every array to ``self.total_data``.

        Returns
        -------
        dict
            Model converted to dict (state_dict key -> numpy array)

        """
        m = dict()
        for key, val in self.model.state_dict().items():
            # Tensor.numpy() does not copy: the array shares memory with the
            # model's tensor, so later in-place updates to the model are
            # visible through it.
            m[key] = val.numpy()
            self.total_data += len(self.communication.encrypt(m[key]))
        return m

    def deserialized_model(self, m):
        """
        Convert received dict to state_dict.

        Parameters
        ----------
        m : dict
            received dict (key -> numpy array)

        Returns
        -------
        state_dict
            state_dict of received; tensors share memory with the arrays

        """
        return {key: torch.from_numpy(value) for key, value in m.items()}

    def step(self):
        """
        Perform a sharing step. Implements D-PSGD.

        Sends the serialized model, tagged with this node's degree and the
        current round number, to the neighbors returned by ``get_neighbors``.
        Then blocks until at least one message is queued for every neighbor.
        Finally, it replaces the local model with the Metropolis-Hastings
        weighted average of the oldest queued model from each neighbor and
        the local model, and increments ``communication_round``.

        A node without neighbors keeps its local model unchanged.

        """
        data = self.serialized_model()
        all_neighbors = self.graph.neighbors(self.uid)
        iter_neighbors = self.get_neighbors(all_neighbors)
        # The receiver needs our full degree (not just the shared subset) to
        # compute the symmetric Metropolis-Hastings weight.
        data["degree"] = len(all_neighbors)
        data["iteration"] = self.communication_round
        for neighbor in iter_neighbors:
            self.communication.send(neighbor, data)

        logging.info("Waiting for messages from neighbors")
        while not self.received_from_all():
            sender, message = self.communication.receive()
            logging.debug("Received model from {}".format(sender))
            degree = message["degree"]
            iteration = message["iteration"]
            del message["degree"]
            del message["iteration"]
            self.peer_deques[sender].append((degree, iteration, message))
            logging.info(
                "Deserialized received model from {} of iteration {}".format(
                    sender, iteration
                )
            )

        logging.info("Starting model averaging after receiving from all neighbors")
        total = dict()
        weight_total = 0
        my_degree = len(self.peer_deques)
        for n, pending in self.peer_deques.items():
            degree, iteration, params = pending.popleft()
            logging.debug(
                "Averaging model from neighbor {} of iteration {}".format(n, iteration)
            )
            params = self.deserialized_model(params)
            # Metropolis-Hastings: w_ij = 1 / (max(deg_i, deg_j) + 1)
            weight = 1 / (max(my_degree, degree) + 1)
            weight_total += weight
            for key, value in params.items():
                if key in total:
                    total[key] += value * weight
                else:
                    total[key] = value * weight

        # The local model gets the remaining mass so that the weights sum to 1.
        own_weight = 1 - weight_total
        for key, value in self.model.state_dict().items():
            if self.peer_deques:
                total[key] += own_weight * value  # Metro-Hastings
            else:
                # Isolated node: nothing was averaged, own_weight == 1, so the
                # local model is kept (the old code raised KeyError here).
                total[key] = own_weight * value

        self.model.load_state_dict(total)

        logging.info("Model averaging complete")

        self.communication_round += 1