import json
import math

import numpy
import torch

from decentralizepy.sharing.Sharing import Sharing


class PartialModel(Sharing):
    """Sharing strategy that exchanges only a fraction ``alpha`` of the model
    parameters each round, selected by sorting the accumulated gradients.
    """

    def __init__(
        self, rank, machine_id, communication, mapping, graph, model, dataset, alpha=1.0
    ):
        """Constructor.

        Parameters mirror :class:`Sharing`; ``alpha`` is the fraction
        (0 < alpha <= 1) of parameters to share each round.
        """
        super().__init__(
            rank, machine_id, communication, mapping, graph, model, dataset
        )
        self.alpha = alpha

    def extract_sorted_gradients(self):
        """Sum the accumulated gradients across rounds and return a list of
        ``(value, parameter_key, flat_index)`` tuples sorted ascending by value.

        Raises
        ------
        AssertionError
            If no gradients have been accumulated yet.
        """
        assert len(self.model.accumulated_gradients) > 0
        # Clone the first entry: aliasing it (as the original code did) makes
        # the in-place "+=" below silently corrupt accumulated_gradients[0].
        gradient_sum = {
            key: val.clone()
            for key, val in self.model.accumulated_gradients[0].items()
        }
        for i in range(1, len(self.model.accumulated_gradients)):
            for key in self.model.accumulated_gradients[i]:
                gradient_sum[key] += self.model.accumulated_gradients[i][key]
        gradient_sequence = []
        for key, gradient in gradient_sum.items():
            # .item() turns each 0-dim tensor into a plain float so tuple
            # sorting is well-defined and downstream JSON encoding works.
            for index, val in enumerate(torch.flatten(gradient)):
                gradient_sequence.append((val.item(), key, index))
        gradient_sequence.sort()
        return gradient_sequence

    def serialized_model(self):
        """Serialize the ``alpha`` fraction of parameters with the smallest
        accumulated gradients.

        Returns
        -------
        dict
            Maps parameter key -> JSON-encoded list of ``[flat_index, value]``
            pairs.
        """
        gradient_sequence = self.extract_sorted_gradients()
        # math.round does not exist; use the round() builtin.
        gradient_sequence = gradient_sequence[
            : round(len(gradient_sequence) * self.alpha)
        ]
        # Hoist state_dict() out of the loop — the original re-fetched the
        # full state dict once per selected parameter.
        flat_params = {
            key: torch.flatten(tensor)
            for key, tensor in self.model.state_dict().items()
        }
        m = dict()
        for _, key, index in gradient_sequence:
            if key not in m:
                m[key] = []
            # list.append takes a single argument; store an [index, value]
            # pair, with .item() so the value is JSON-serializable.
            m[key].append([index, flat_params[key][index].item()])
        for key in m:
            m[key] = json.dumps(m[key])
        return m

    def deserialized_model(self, m):
        """Apply a partial update produced by :meth:`serialized_model` onto
        the local state dict and return the updated state dict.

        Parameters
        ----------
        m : dict
            Maps parameter key -> JSON-encoded list of ``[flat_index, value]``
            pairs.
        """
        state_dict = self.model.state_dict()
        for key, value in m.items():
            # torch.flatten may return a copy for non-contiguous tensors, so
            # write into an explicit flat copy and reshape it back rather
            # than relying on view semantics. The original also overwrote
            # the tensor with the raw [index, value] pairs afterwards,
            # destroying the result — that line is removed.
            flat = torch.flatten(state_dict[key]).clone()
            for index, param_val in json.loads(value):
                flat[index] = param_val
            state_dict[key] = flat.reshape(state_dict[key].shape)
        return state_dict