Skip to content
Snippets Groups Projects
Commit 2b7c3661 authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Default value for random_seed (backward compatibility) and indices to int32

parent e8bfbe5b
No related branches found
No related tags found
No related merge requests found
......@@ -124,7 +124,10 @@ class Node:
"""
dataset_module = importlib.import_module(dataset_configs["dataset_package"])
self.dataset_class = getattr(dataset_module, dataset_configs["dataset_class"])
torch.manual_seed(dataset_configs["random_seed"])
random_seed = (
dataset_configs["random_seed"] if "random_seed" in dataset_configs else 97
)
torch.manual_seed(random_seed)
self.dataset_params = utils.remove_keys(
dataset_configs,
["dataset_package", "dataset_class", "model_class", "random_seed"],
......
......@@ -3,6 +3,7 @@ import logging
import os
from pathlib import Path
import numpy as np
import torch
from decentralizepy.sharing.Sharing import Sharing
......@@ -155,7 +156,7 @@ class PartialModel(Sharing):
if not self.dict_ordered:
raise NotImplementedError
m["indices"] = G_topk.numpy()
m["indices"] = G_topk.numpy().astype(np.int32)
m["params"] = T_topk.numpy()
......@@ -206,7 +207,7 @@ class PartialModel(Sharing):
tensors_to_cat.append(t)
T = torch.cat(tensors_to_cat, dim=0)
index_tensor = torch.tensor(m["indices"])
index_tensor = torch.tensor(m["indices"], dtype=torch.long)
logging.debug("Original tensor: {}".format(T[index_tensor]))
T[index_tensor] = torch.tensor(m["params"])
logging.debug("Final tensor: {}".format(T[index_tensor]))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment