import logging
import math

from decentralizepy.sharing.PartialModel import PartialModel


class GrowingAlpha(PartialModel):
    """
    PartialModel sharing whose ``alpha`` moves step-wise from ``init_alpha``
    toward ``max_alpha``.

    Every ``k`` communication rounds, alpha changes by
    ``(max_alpha - init_alpha) / k``. The target is therefore reached after
    ``k`` increments, i.e. after roughly ``k * k`` rounds. Alpha never goes
    beyond ``max_alpha`` and never exceeds 1.0. While alpha is exactly 0.0 the
    round's data exchange is skipped entirely.

    """

    def __init__(
        self,
        rank,
        machine_id,
        communication,
        mapping,
        graph,
        model,
        dataset,
        log_dir,
        init_alpha=0.0,
        max_alpha=1.0,
        k=10,
        dict_ordered=True,
        save_shared=False,
        metadata_cap=1.0,
    ):
        """
        Constructor

        Parameters
        ----------
        rank, machine_id, communication, mapping, graph, model, dataset, log_dir
            Forwarded unchanged to ``PartialModel``.
        init_alpha : float
            Starting alpha. It is passed to ``PartialModel`` in the argument
            slot after ``log_dir`` (presumably its ``alpha`` parameter, which
            ``step`` later reads back as ``self.alpha`` -- confirm against
            PartialModel).
        max_alpha : float
            Value alpha moves toward; effectively capped at 1.0.
        k : int
            Growth period in rounds, and also the number of increments needed
            to reach ``max_alpha``. Must be positive.
        dict_ordered, save_shared, metadata_cap
            Forwarded unchanged to ``PartialModel``.

        Raises
        ------
        ValueError
            If ``k`` is not positive (it is used as both a modulus and a divisor).

        """
        # Validate before the parent constructor does any setup work.
        if k <= 0:
            raise ValueError(f"k must be a positive integer, got {k!r}")
        super().__init__(
            rank,
            machine_id,
            communication,
            mapping,
            graph,
            model,
            dataset,
            log_dir,
            init_alpha,
            dict_ordered,
            save_shared,
            metadata_cap,
        )
        self.init_alpha = init_alpha
        self.max_alpha = max_alpha
        self.k = k

    def _grow_alpha(self):
        """
        Move ``self.alpha`` one increment toward ``max_alpha``, without overshooting.

        """
        # The previous code capped only at 1.0, so any max_alpha < 1.0 was ignored.
        target = min(self.max_alpha, 1.0)
        self.alpha += (self.max_alpha - self.init_alpha) / self.k
        # Clamp between the start and the target. Sorting keeps this correct
        # even for a decreasing schedule (max_alpha < init_alpha); the 1.0
        # ceiling from the original code is preserved.
        low, high = sorted((self.init_alpha, target))
        self.alpha = max(low, min(self.alpha, high, 1.0))
        # Repeated float addition drifts (ten additions of 0.1 give
        # 0.9999999999999999), so snap exactly onto the target once within
        # rounding error of it.
        if math.isclose(self.alpha, target, abs_tol=1e-12):
            self.alpha = target

    def step(self):
        """
        Perform one communication round, growing alpha every ``k`` rounds.

        If alpha is exactly 0.0, the exchange is skipped and only the round
        counter advances. Otherwise the round is delegated to
        ``PartialModel.step``.

        """
        # Grow once every k rounds (after rounds k-1, 2k-1, ...).
        if (self.communication_round + 1) % self.k == 0:
            self._grow_alpha()

        if self.alpha == 0.0:
            logging.info("Not sending/receiving data (alpha=0.0)")
            # super().step() is skipped, so advance the round counter here
            # (presumably PartialModel.step increments it itself -- confirm).
            self.communication_round += 1
            return

        super().step()