# Skip to content
# Snippets Groups Projects
# PartialModel.py 1.83 KiB
# Newer Older
import json
import math

import numpy
import torch

# Rishi Sharma's avatar
# Rishi Sharma committed
from decentralizepy.sharing.Sharing import Sharing

# Rishi Sharma's avatar
# Rishi Sharma committed
class PartialModel(Sharing):
    """
    Sharing strategy that serializes only part of the model.

    The individual parameter entries to share are chosen from the model's
    accumulated gradients: the summed gradient entries are sorted and a
    leading fraction ``alpha`` of them selects which parameter values are
    serialized.

    """

    def __init__(
        self, rank, machine_id, communication, mapping, graph, model, dataset, alpha=1.0
    ):
        """
        Initialize the sharing object.

        All arguments except ``alpha`` are forwarded unchanged to
        ``Sharing.__init__``.

        Parameters
        ----------
        alpha : float
            Fraction of the sorted gradient entries whose parameter values are
            serialized by ``serialized_model``. The default 1.0 keeps all.

        """
        super().__init__(
            rank, machine_id, communication, mapping, graph, model, dataset
        )
        self.alpha = alpha

    def extract_sorted_gradients(self):
        """
        Sum the accumulated gradients and return their entries in sorted order.

        Returns
        -------
        list
            ``(value, key, index)`` tuples sorted ascending, where ``value`` is
            the summed gradient entry as a Python float, ``key`` is the
            gradient dict key and ``index`` is the position of the entry in
            the flattened tensor.

        Raises
        ------
        AssertionError
            If ``self.model.accumulated_gradients`` is empty.

        """
        accumulated = self.model.accumulated_gradients
        assert len(accumulated) > 0

        remaining = iter(accumulated)
        # Clone the first entry so that summing does not modify the gradients
        # stored on the model (repeated calls would otherwise double count).
        gradient_sum = {
            key: gradient.detach().clone() for key, gradient in next(remaining).items()
        }
        for gradients in remaining:
            for key, gradient in gradients.items():
                gradient_sum[key] += gradient

        # tolist() yields plain floats in one pass; the resulting sort order is
        # the same as comparing the individual 0-dim tensors.
        gradient_sequence = [
            (val, key, index)
            for key, gradient in gradient_sum.items()
            for index, val in enumerate(torch.flatten(gradient).tolist())
        ]

        # NOTE(review): the order is ascending by signed value, so a prefix of
        # this list holds the smallest (most negative) sums. Confirm whether
        # selection by largest magnitude was intended.
        gradient_sequence.sort()
        return gradient_sequence

    def serialized_model(self):
        """
        Serialize the parameter values selected by the sorted gradients.

        Returns
        -------
        dict
            Maps each selected state-dict key to a JSON string encoding a list
            of ``[index, value]`` pairs, ``index`` being the position in the
            flattened parameter tensor.

        """
        gradient_sequence = self.extract_sorted_gradients()
        # Builtin round() returns the int required for slicing; math has no
        # round(). The clamp keeps a negative alpha from turning into a
        # "drop from the end" slice.
        keep = max(0, round(len(gradient_sequence) * self.alpha))

        selected = {}
        for _, key, index in gradient_sequence[:keep]:
            selected.setdefault(key, []).append(index)

        # Fetch the state dict once instead of once per selected entry.
        state_dict = self.model.state_dict()

        m = {}
        for key, indices in selected.items():
            flat = torch.flatten(state_dict[key]).detach()
            positions = torch.tensor(indices, dtype=torch.long, device=flat.device)
            # tolist() converts to Python numbers, which json can encode.
            values = flat[positions].tolist()
            m[key] = json.dumps(list(zip(indices, values)))

        return m

    def deserialized_model(self, m):
        """
        Build a state dict from a message produced by ``serialized_model``.

        Entries present in ``m`` replace the corresponding values of this
        node's own state dict; all other values are taken from the local model.
        The local model's parameters are not modified.

        Parameters
        ----------
        m : dict
            Maps state-dict keys to JSON strings of ``[index, value]`` pairs.

        Returns
        -------
        dict
            State dict with the received values applied.

        """
        state_dict = self.model.state_dict()

        for key, value in m.items():
            pairs = json.loads(value)
            if not pairs:
                continue

            original = state_dict[key]
            # Work on a copy: flatten() may return a view, and writing through
            # it would overwrite the local model's parameters.
            flat = original.detach().flatten().clone()
            positions = torch.tensor(
                [index for index, _ in pairs], dtype=torch.long, device=flat.device
            )
            values = torch.tensor(
                [param_val for _, param_val in pairs],
                dtype=flat.dtype,
                device=flat.device,
            )
            flat[positions] = values
            state_dict[key] = flat.reshape(original.shape)

        return state_dict