Skip to content
Snippets Groups Projects
Commit 9d4a7b6e authored by Rishi Sharma's avatar Rishi Sharma
Browse files

Partial model sharing fix

parent 15af7b00
No related branches found
No related tags found
No related merge requests found
4
0 1
0 2
0 3
1 0
1 2
1 3
2 0
2 1
2 3
3 0
3 1
3 2
......@@ -2,7 +2,7 @@
dataset_package = decentralizepy.datasets.Femnist
dataset_class = Femnist
model_class = CNN
n_procs = 16
n_procs = 4
train_dir = /home/risharma/Gitlab/decentralizepy/leaf/data/femnist/per_user_data/train
test_dir = /home/risharma/Gitlab/decentralizepy/leaf/data/femnist/data/test
; python list of fractions below
......@@ -14,9 +14,9 @@ optimizer_class = Adam
lr = 0.01
[TRAIN_PARAMS]
training_package = decentralizepy.training.Training
training_class = Training
epochs_per_round = 5
training_package = decentralizepy.training.GradientAccumulator
training_class = GradientAccumulator
epochs_per_round = 2
batch_size = 1024
shuffle = True
loss_package = torch.nn
......@@ -28,5 +28,6 @@ comm_class = TCP
addresses_filepath = ip_addr_6Machines.json
[SHARING]
sharing_package = decentralizepy.sharing.Sharing
sharing_class = Sharing
\ No newline at end of file
sharing_package = decentralizepy.sharing.PartialModel
sharing_class = PartialModel
alpha = 0.5
\ No newline at end of file
%% Cell type:code id: tags:
```
from datasets.Femnist import Femnist
from graphs import SmallWorld
from collections import defaultdict
import os
import json
import numpy as np
```
%% Cell type:code id: tags:
```
a = FEMNIST
a
```
%% Cell type:code id: tags:
```
b = SmallWorld(6, 2, 2, 1)
```
%% Cell type:code id: tags:
```
b.adj_list
```
%% Cell type:code id: tags:
```
for i in range(12):
print(b.neighbors(i))
```
%% Cell type:code id: tags:
```
clients = []
```
%% Cell type:code id: tags:
```
num_samples = []
data = defaultdict(lambda : None)
```
%% Cell type:code id: tags:
```
datadir = "./leaf/data/femnist/data/train"
files = os.listdir(datadir)
total_users=0
users = set()
```
%% Cell type:code id: tags:
```
files = os.listdir(datadir)[0:1]
```
%% Cell type:code id: tags:
```
for f in files:
file_path = os.path.join(datadir, f)
print(file_path)
with open(file_path, 'r') as inf:
client_data = json.load(inf)
current_users = len(client_data['users'])
print("Current_Users: ", current_users)
total_users += current_users
users.update(client_data['users'])
print("total_users: ", total_users)
print("total_users: ", len(users))
print(client_data['user_data'].keys())
print(np.array(client_data['user_data']['f3408_47']['x']).shape)
print(np.array(client_data['user_data']['f3408_47']['y']).shape)
print(np.array(client_data['user_data']['f3327_11']['x']).shape)
print(np.array(client_data['user_data']['f3327_11']['y']).shape)
print(np.unique(np.array(client_data['user_data']['f3327_11']['y'])))
```
%% Cell type:code id: tags:
```
file = 'run.py'
with open(file, 'r') as inf:
print(inf.readline().strip())
print(inf.readlines())
```
%% Cell type:code id: tags:
```
def f(l):
l[2] = 'c'
a = ['a', 'a', 'a']
print(a)
f(a)
print(a)
```
%% Cell type:code id: tags:
```
l = ['a', 'b', 'c']
print(l[:-1])
```
%% Cell type:code id: tags:
```
from localconfig import LocalConfig
def read_ini(file_path):
config = LocalConfig(file_path)
for section in config:
print("Section: ", section)
for key, value in config.items(section):
print((key, value))
print(dict(config.items('DATASET')))
return config
config = read_ini("config.ini")
for section in config:
print(section)
#d = dict(config.sections())
```
%% Cell type:code id: tags:
```
def func(a = 1, b = 2, c = 3):
print(a + b + c)
l = [3, 5, 7]
func(*l)
```
%% Cell type:code id: tags:
```
from torch import multiprocessing as mp
mp.spawn(fn = func, nprocs = 2, args = [], kwargs = {'a': 4, 'b': 5, 'c': 6})
```
%% Cell type:code id: tags:
```
l = '[0.4, 0.2, 0.3, 0.1]'
type(eval(l))
```
%% Cell type:code id: tags:
```
from decentralizepy.datasets.Femnist import Femnist
f1 = Femnist(0, 1, 'leaf/data/femnist/data/train')
ts = f1.get_trainset(1)
for data, target in ts:
print(data)
break
```
%% Cell type:code id: tags:
```
from decentralizepy.datasets.Femnist import Femnist
from decentralizepy.graphs.SmallWorld import SmallWorld
from decentralizepy.mappings.Linear import Linear
f = Femnist(2, 'leaf/data/femnist/data/train', sizes=[0.6, 0.4])
g = SmallWorld(4, 1, 0.5)
l = Linear(2, 2)
```
%% Cell type:code id: tags:
```
from decentralizepy.node.Node import Node
from torch import multiprocessing as mp
import logging
n1 = Node(0, l, g, f, "./results", logging.DEBUG)
n2 = Node(1, l, g, f, "./results", logging.DEBUG)
# mp.spawn(fn = Node, nprocs = 2, args=[l,g,f])
```
%% Cell type:code id: tags:
```
from testing import f
```
%% Cell type:code id: tags:
```
from torch import multiprocessing as mp
import torch
m1 = torch.nn.Linear(1,1)
o1 = torch.optim.SGD(m1.parameters(), 0.6)
print(m1)
mp.spawn(fn = f, nprocs = 2, args=[m1, o1])
```
%% Cell type:markdown id: tags:
%% Cell type:code id: tags:
```
o1.param_groups
```
%% Cell type:code id: tags:
```
with torch.no_grad():
o1.param_groups[0]["params"][0].copy_(torch.zeros(1,))
```
%% Cell type:code id: tags:
```
o1.param_groups
```
%% Cell type:code id: tags:
```
m1.state_dict()
```
%% Cell type:code id: tags:
```
import torch
loss = getattr(torch.nn.functional, 'nll_loss')
```
%% Cell type:code id: tags:
```
loss
```
%% Cell type:code id: tags:
```
%matplotlib inline
from decentralizepy.node.Node import Node
from decentralizepy.graphs.SmallWorld import SmallWorld
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from torch import multiprocessing as mp
import torch
import logging
from localconfig import LocalConfig
def read_ini(file_path):
config = LocalConfig(file_path)
for section in config:
print("Section: ", section)
for key, value in config.items(section):
print((key, value))
print(dict(config.items('DATASET')))
return config
config = read_ini("config.ini")
my_config = dict()
for section in config:
my_config[section] = dict(config.items(section))
#f = Femnist(2, 'leaf/data/femnist/data/train', sizes=[0.6, 0.4])
g = Graph()
g.read_graph_from_file("36_nodes.edges", "edges")
l = Linear(1, 36)
#Node(0, 0, l, g, my_config, 20, "results", logging.DEBUG)
mp.spawn(fn = Node, nprocs = g.n_procs, args=[0,l,g,my_config,20,"results",logging.INFO])
# mp.spawn(fn = Node, args = [l, g, config, 10, "results", logging.DEBUG], nprocs=2)
```
%% Output
Section: GRAPH
('package', 'decentralizepy.graphs.SmallWorld')
('graph_class', 'SmallWorld')
Section: DATASET
('dataset_package', 'decentralizepy.datasets.Femnist')
('dataset_class', 'Femnist')
('model_class', 'CNN')
('n_procs', 36)
('train_dir', 'leaf/data/femnist/per_user_data/train')
('test_dir', 'leaf/data/femnist/data/test')
('sizes', '')
Section: OPTIMIZER_PARAMS
('optimizer_package', 'torch.optim')
('optimizer_class', 'Adam')
('lr', 0.01)
Section: TRAIN_PARAMS
('training_package', 'decentralizepy.training.Training')
('training_class', 'Training')
('epochs_per_round', 1)
('batch_size', 1024)
('shuffle', True)
('loss_package', 'torch.nn')
('loss_class', 'CrossEntropyLoss')
Section: COMMUNICATION
('comm_package', 'decentralizepy.communication.TCP')
('comm_class', 'TCP')
('addresses_filepath', 'ip_addr.json')
Section: SHARING
('sharing_package', 'decentralizepy.sharing.Sharing')
('sharing_class', 'Sharing')
{'dataset_package': 'decentralizepy.datasets.Femnist', 'dataset_class': 'Femnist', 'model_class': 'CNN', 'n_procs': 36, 'train_dir': 'leaf/data/femnist/per_user_data/train', 'test_dir': 'leaf/data/femnist/data/test', 'sizes': ''}
%% Cell type:code id: tags:
```
```
%% Cell type:code id: tags:
```
from decentralizepy.mappings.Linear import Linear
from testing import f
from torch import multiprocessing as mp
l = Linear(1, 2)
mp.spawn(fn = f, nprocs = 2, args = [0, 2, "ip_addr.json", l])
```
%% Cell type:code id: tags:
```
from decentralizepy.datasets.Femnist import Femnist
f = Femnist()
f.file_per_user('../leaf/data/femnist/data/train','../leaf/data/femnist/per_user_data/train')
```
%% Cell type:code id: tags:
```
a = set()
a.update([2, 3, 4, 5])
```
%% Cell type:code id: tags:
```
a
```
%% Output
{2, 3, 4, 5}
%% Cell type:code id: tags:
```
print(*a)
```
%% Output
2 3 4 5
%% Cell type:code id: tags:
```
from decentralizepy.graphs.FullyConnected import FullyConnected
s = FullyConnected(4)
s.write_graph_to_file('4_node_fullyConnected.edges')
```
%% Cell type:code id: tags:
```
from decentralizepy.graphs.SmallWorld import SmallWorld
s = SmallWorld(96, 2, .5)
s.write_graph_to_file('96_nodes.edges')
```
%% Cell type:code id: tags:
```
import sys
sys.argv
```
%% Output
['/home/risharma/miniconda3/envs/decpy/lib/python3.9/site-packages/ipykernel_launcher.py',
'--ip=127.0.0.1',
'--stdin=9008',
'--control=9006',
'--hb=9005',
'--Session.signature_scheme="hmac-sha256"',
'--Session.key=b"eac5d2f8-c460-45f1-a268-1e4b46a6efd6"',
'--shell=9007',
'--transport="tcp"',
'--iopub=9009',
'--f=/tmp/tmp-21212479paJaUBJBN84.json']
%% Cell type:code id: tags:
```
import torch
from decentralizepy.datasets.Femnist import CNN
m1 = CNN()
o1 = torch.optim.SGD(m1.parameters(), 0.6)
print("m1_parameters: ", {k:v.data for k, v in zip(m1.state_dict(), m1.parameters())})
#print("m1_state_dict: ", m1.state_dict())
#print("o1_state_dict: ", o1.state_dict())
```
%% Output
m1_parameters: {'conv1.weight': tensor([[[[-0.0294, -0.0659, -0.1267, -0.1610, -0.1462],
[-0.1133, 0.1424, 0.1439, -0.0677, 0.1200],
[-0.0530, 0.1776, 0.1445, 0.0141, -0.1778],
[ 0.1847, 0.1211, 0.0345, 0.1131, 0.0740],
[ 0.1136, -0.0253, -0.1799, 0.1309, -0.0115]]],
[[[-0.0018, 0.0063, 0.1890, 0.0434, 0.1667],
[ 0.0428, 0.0402, 0.0275, 0.0432, 0.0524],
[ 0.1926, 0.0992, -0.0925, -0.1610, -0.1305],
[-0.1714, 0.0548, -0.1500, 0.1679, -0.0767],
[ 0.0094, 0.0525, 0.0273, -0.0462, -0.1696]]],
[[[-0.0033, -0.1317, -0.1695, 0.0221, 0.1301],
[ 0.0789, 0.0397, -0.1292, 0.0642, 0.0683],
[ 0.1198, 0.1179, 0.0186, -0.1519, -0.1354],
[ 0.1598, -0.1400, 0.0738, -0.1385, -0.1195],
[-0.0458, 0.1540, 0.0317, 0.1788, -0.0782]]],
[[[ 0.1726, -0.0152, -0.1036, -0.1624, 0.0804],
[ 0.1492, -0.0097, -0.1417, -0.1234, 0.1393],
[ 0.1567, -0.0843, 0.1540, -0.0776, -0.1575],
[-0.1371, -0.1458, 0.0518, 0.0329, 0.0664],
[ 0.1976, -0.0956, 0.0785, -0.1633, -0.1674]]],
[[[ 0.1716, -0.1734, -0.0078, 0.1092, -0.1912],
[ 0.1574, -0.1333, -0.0146, 0.0534, -0.1207],
[-0.1233, 0.0647, 0.0755, 0.1025, 0.0893],
[ 0.1158, -0.0836, -0.0749, 0.0244, -0.1928],
[-0.1903, -0.0820, 0.0529, 0.0907, 0.0951]]],
[[[-0.1117, -0.1858, -0.0375, -0.0397, 0.0412],
[ 0.0176, 0.1707, -0.0648, -0.0327, -0.1994],
[ 0.1754, -0.0413, -0.1533, -0.0537, 0.1160],
[ 0.1145, 0.0755, -0.0242, 0.0387, 0.1763],
[-0.0203, -0.1236, 0.0372, -0.0074, 0.0423]]],
[[[ 0.0021, 0.1159, 0.0160, 0.1778, -0.1911],
[ 0.0584, -0.0669, 0.1317, -0.1835, 0.0421],
[-0.1418, -0.1790, 0.1748, -0.0804, 0.0495],
[ 0.0498, -0.0983, 0.0343, 0.1142, -0.0266],
[ 0.0275, 0.0369, 0.1941, 0.1219, -0.0856]]],
[[[ 0.0191, -0.1786, -0.1492, -0.1155, 0.0068],
[-0.1917, 0.0696, -0.0657, 0.0317, -0.1831],
[-0.1451, -0.0804, 0.0266, -0.1905, 0.1968],
[-0.0370, 0.0943, 0.0535, -0.0174, 0.0636],
[ 0.1038, 0.0251, 0.0469, 0.0042, 0.1237]]],
[[[ 0.1488, 0.0967, 0.1249, 0.0799, -0.1052],
[-0.0188, -0.1058, -0.1708, -0.0282, 0.0980],
[-0.1925, 0.1517, 0.1029, -0.0329, -0.0788],
[ 0.0181, -0.1959, 0.1086, 0.0699, -0.0163],
[-0.0231, -0.0422, 0.1669, 0.1058, -0.1054]]],
[[[-0.1371, 0.0173, 0.0834, -0.1819, -0.1893],
[-0.1381, 0.1158, 0.0361, -0.1698, 0.1006],
[ 0.0385, -0.0467, 0.0824, -0.0762, -0.1883],
[ 0.1954, 0.1819, -0.0178, 0.1994, 0.1103],
[ 0.1728, -0.1546, 0.1065, -0.0425, 0.1576]]],
[[[-0.1496, -0.1969, -0.1935, 0.1659, 0.1540],
[-0.0021, -0.1386, -0.0502, -0.0283, 0.1867],
[ 0.1308, 0.1979, -0.0033, -0.1406, 0.0059],
[-0.1429, -0.0778, -0.0568, -0.1306, -0.0380],
[ 0.0740, -0.1416, -0.1770, 0.0070, -0.0831]]],
[[[ 0.0831, -0.0220, 0.0272, 0.0323, 0.1351],
[-0.1093, -0.1390, 0.0110, 0.0742, -0.0337],
[-0.0964, 0.1865, 0.1744, 0.0147, -0.0739],
[ 0.1436, -0.0275, 0.0454, 0.1006, -0.1111],
[ 0.1203, -0.0290, -0.0833, -0.0533, -0.1128]]],
[[[-0.0833, -0.1054, 0.1966, 0.1635, 0.0234],
[-0.1707, -0.0597, -0.0804, 0.1331, 0.1496],
[-0.1735, 0.1877, 0.1007, -0.1540, 0.1272],
[ 0.0859, 0.0146, -0.0869, 0.1998, 0.1968],
[ 0.1279, 0.0157, 0.0449, -0.0243, -0.1675]]],
[[[ 0.1474, 0.0966, -0.1795, -0.1186, -0.1994],
[ 0.1110, -0.1770, 0.1661, 0.1203, -0.1071],
[-0.0106, 0.0099, 0.0663, 0.0750, 0.0097],
[-0.0487, -0.0870, -0.0138, 0.0861, 0.0155],
[ 0.0640, 0.0427, 0.0353, 0.0669, -0.1381]]],
[[[ 0.0055, 0.1545, -0.0684, 0.1604, 0.0686],
[-0.1543, 0.0619, -0.1355, -0.0981, 0.1642],
[ 0.0907, 0.1712, 0.1466, -0.0799, -0.0670],
[ 0.0948, -0.1994, 0.0710, -0.0344, -0.1059],
[ 0.0230, -0.0667, 0.1027, 0.1431, -0.1946]]],
[[[-0.1192, -0.0009, -0.0426, 0.1015, 0.1718],
[-0.0381, 0.1195, -0.1853, -0.0561, -0.0592],
[-0.0719, 0.0546, 0.0530, 0.0038, -0.1063],
[ 0.1732, 0.1704, 0.1143, -0.1093, -0.0474],
[ 0.1588, 0.0620, 0.1074, 0.0200, 0.1258]]],
[[[-0.0276, -0.0863, -0.0745, 0.1831, 0.0149],
[ 0.0925, 0.1106, -0.1559, 0.0043, 0.1982],
[ 0.0403, 0.1336, -0.1262, 0.0649, 0.1244],
[ 0.1716, -0.0994, -0.1754, 0.1991, 0.1502],
[-0.1226, -0.0286, 0.1742, -0.0364, 0.1818]]],
[[[ 0.1905, 0.0292, -0.1152, 0.0816, -0.1832],
[-0.0553, 0.1028, 0.1891, 0.0993, -0.1132],
[-0.0641, 0.0048, -0.0417, -0.0937, 0.1667],
[-0.1644, -0.1126, 0.1683, -0.0912, -0.1931],
[ 0.0812, 0.1060, 0.0252, 0.0874, 0.1943]]],
[[[-0.0824, -0.1813, -0.0630, 0.0597, 0.0138],
[ 0.1787, -0.0708, -0.1641, 0.0792, 0.1893],
[ 0.0698, 0.1759, 0.1484, 0.1875, 0.1213],
[ 0.1149, 0.0910, 0.1125, -0.0134, 0.1492],
[ 0.1495, -0.1414, -0.0528, -0.1981, 0.1301]]],
[[[-0.1728, 0.0973, 0.1402, -0.1365, 0.0223],
[-0.1623, -0.1194, -0.0047, 0.1541, -0.1720],
[-0.1631, 0.1378, 0.1817, -0.0472, 0.1947],
[ 0.1213, -0.1986, -0.1415, 0.0990, -0.1460],
[-0.0185, 0.1019, 0.1401, 0.1501, -0.0396]]],
[[[ 0.1756, -0.0398, -0.1515, -0.0882, -0.0355],
[ 0.0670, 0.0649, 0.1082, 0.1635, 0.0461],
[ 0.1132, 0.0734, -0.0098, 0.0909, 0.0581],
[-0.1208, 0.1435, 0.1345, 0.1569, -0.1357],
[-0.1636, -0.0769, -0.1814, 0.1030, -0.0982]]],
[[[ 0.1056, 0.1005, -0.1637, 0.0073, -0.1990],
[-0.1068, 0.0294, -0.1704, 0.0204, -0.0671],
[ 0.0983, -0.0194, 0.1863, 0.0016, 0.0691],
[ 0.0071, -0.1114, 0.1376, -0.0950, 0.1963],
[-0.0081, 0.1638, 0.1705, 0.1929, 0.0170]]],
[[[-0.0803, -0.1001, 0.0113, 0.0315, -0.0165],
[-0.0962, 0.0289, 0.0582, 0.0729, 0.1783],
[-0.0604, 0.1725, 0.1990, -0.1395, -0.0585],
[ 0.1276, -0.0908, 0.0208, -0.1839, -0.0880],
[-0.0839, 0.1635, -0.0708, -0.1091, 0.0110]]],
[[[-0.1190, -0.0079, -0.0216, -0.1186, -0.0047],
[ 0.0116, 0.1451, -0.0582, -0.0773, 0.1828],
[-0.0389, -0.0040, -0.1912, -0.1790, -0.1594],
[-0.1674, -0.0784, -0.0527, -0.1289, -0.1193],
[-0.1707, -0.1716, 0.1906, 0.1985, -0.0731]]],
[[[ 0.1145, 0.1584, 0.0135, -0.0490, 0.1506],
[-0.1641, 0.0264, 0.1881, -0.0479, -0.0281],
[-0.0313, -0.0920, -0.0546, -0.1347, 0.0836],
[-0.0193, 0.1860, -0.0454, 0.1385, -0.1352],
[-0.1561, 0.0526, 0.1160, 0.0300, 0.1154]]],
[[[-0.1478, -0.1039, -0.1727, 0.0112, -0.0403],
[-0.0408, 0.1417, -0.0247, -0.0756, -0.1615],
[ 0.0093, -0.1076, -0.0906, -0.1624, 0.1284],
[-0.1374, 0.1409, -0.1797, -0.0802, 0.1416],
[ 0.0282, 0.1258, 0.0230, -0.0541, 0.0536]]],
[[[-0.0640, 0.0140, -0.1493, -0.0449, -0.0952],
[ 0.0109, -0.0120, -0.0977, -0.1969, -0.1897],
[ 0.1005, 0.0844, 0.1638, 0.0776, -0.1811],
[-0.0127, -0.1358, -0.1198, 0.0929, -0.0811],
[ 0.1094, 0.1268, 0.0769, -0.0800, 0.0134]]],
[[[ 0.1320, 0.1890, -0.1937, 0.0947, -0.1642],
[ 0.0090, -0.0804, -0.1137, -0.0412, -0.1253],
[-0.1924, 0.1154, 0.0567, -0.1458, -0.0735],
[-0.0265, 0.0895, -0.1165, -0.0549, 0.1763],
[-0.1959, -0.0329, -0.0194, 0.0983, -0.0659]]],
[[[-0.1925, -0.1036, 0.1780, -0.0791, -0.0873],
[ 0.0100, 0.1510, -0.1453, -0.0745, 0.0458],
[-0.0419, 0.0820, 0.1765, -0.1156, -0.0218],
[ 0.0933, 0.1453, -0.1843, -0.1624, 0.0401],
[ 0.0747, 0.0421, 0.1151, 0.1696, -0.0365]]],
[[[ 0.1575, 0.1559, -0.1104, -0.1436, -0.1991],
[-0.0338, -0.1194, -0.1659, 0.0048, -0.1487],
[ 0.0137, 0.0668, 0.0671, -0.0339, 0.0486],
[-0.0064, -0.0225, 0.0927, 0.0606, 0.0042],
[ 0.1252, -0.1965, 0.0352, 0.1180, 0.0896]]],
[[[ 0.0671, 0.1113, 0.0242, -0.0552, -0.1848],
[-0.1905, 0.0019, 0.0057, -0.0307, 0.1718],
[-0.0562, -0.1494, -0.1637, -0.1111, 0.0126],
[ 0.0556, 0.1048, 0.1284, 0.0417, -0.0556],
[-0.0655, 0.1431, -0.1373, 0.0311, -0.1628]]],
[[[ 0.0414, -0.0397, -0.0018, 0.1074, 0.1924],
[ 0.0205, 0.1236, -0.1880, 0.0947, -0.0946],
[-0.0543, -0.0087, 0.0633, -0.1134, 0.0912],
[ 0.0875, 0.0397, -0.1993, 0.1947, -0.1831],
[ 0.1359, 0.1628, -0.0632, -0.1867, -0.0839]]]]), 'conv1.bias': tensor([ 0.0303, 0.1421, 0.0549, -0.0636, 0.0261, -0.0631, 0.0215, 0.1342,
-0.0792, -0.1747, 0.0829, 0.1978, -0.1716, -0.1050, -0.0049, 0.1790,
0.1964, 0.0633, 0.0980, -0.0159, 0.0837, -0.1232, -0.0526, -0.1208,
-0.1421, -0.0880, 0.1810, -0.0636, 0.1665, 0.1121, -0.1900, 0.0091]), 'conv2.weight': tensor([[[[-1.0712e-02, -2.8543e-02, 2.5060e-03, 9.5242e-03, 3.1358e-02],
[-1.4135e-02, 1.3580e-02, 2.1455e-02, 1.2194e-02, 1.4131e-02],
[ 3.0316e-02, -1.1152e-02, -3.2848e-02, -2.2067e-02, -2.6107e-02],
[ 8.4410e-03, 3.9427e-03, 1.2867e-02, -1.2513e-03, -2.1535e-02],
[-3.9282e-03, 8.2564e-03, 1.4120e-02, 2.8508e-02, 3.9923e-03]],
[[-1.4297e-03, 4.7124e-03, -3.4464e-02, 1.4309e-02, -3.4432e-02],
[ 2.6338e-02, -4.2020e-03, 5.5271e-03, 3.1447e-02, 3.5195e-02],
[-9.7311e-03, 3.0671e-02, 3.3432e-02, -1.6397e-02, 9.7262e-03],
[-3.1700e-02, -2.4078e-02, -5.8298e-03, -7.3698e-03, -3.1840e-02],
[ 1.7487e-02, -2.2445e-02, -4.2725e-03, 1.6929e-02, -1.6501e-02]],
[[ 2.6720e-02, 1.5540e-02, 2.3545e-02, -3.1226e-03, 4.3139e-03],
[ 2.4492e-02, -1.6616e-02, -2.1790e-02, 3.0564e-02, -1.6427e-02],
[-1.7733e-02, -2.4438e-02, -1.9700e-02, 2.1084e-02, 2.9459e-02],
[-3.0856e-02, -3.5242e-02, 2.0413e-02, 4.1613e-03, 8.3422e-03],
[-1.5534e-02, -1.4875e-02, 1.6494e-02, -3.2325e-02, -1.1099e-04]],
...,
[[ 1.2833e-02, -2.0625e-02, -2.0700e-02, -1.3513e-02, 1.5524e-02],
[ 7.5736e-03, 3.2956e-02, 2.6385e-02, -2.1103e-02, -2.3895e-02],
[-1.9102e-02, 1.1805e-02, 1.1777e-02, -3.0465e-02, 2.6048e-03],
[ 6.3230e-04, 1.0886e-02, -1.5035e-02, -1.5262e-02, -2.5167e-02],
[-2.1168e-03, 2.6181e-02, 1.4023e-02, 2.2960e-02, 3.0476e-02]],
[[-2.1003e-02, -2.1934e-02, 1.6896e-02, -7.1724e-03, -2.8637e-02],
[ 2.8492e-03, 2.2367e-02, -3.2999e-02, -2.8547e-02, 2.5825e-02],
[-2.4395e-02, -1.2782e-02, 1.6746e-02, -1.8496e-02, -2.7374e-02],
[-2.0825e-03, 1.3699e-02, 2.0900e-02, -1.0655e-02, -2.0718e-02],
[ 2.3637e-03, -1.2933e-02, 1.3596e-03, -1.4176e-02, -2.7697e-02]],
[[ 1.6140e-02, -9.6101e-03, 1.6965e-02, -2.4911e-02, 1.7669e-02],
[ 6.7341e-03, 1.9680e-02, -2.4388e-02, 2.5657e-02, -3.5043e-02],
[-1.7948e-02, -1.9798e-02, 3.2972e-02, -1.0105e-02, 3.2288e-02],
[ 6.9848e-03, -2.0427e-02, 3.0102e-02, 2.0419e-02, 1.1636e-02],
[ 1.4831e-02, 2.5886e-02, 8.4296e-03, 3.6944e-03, 1.2589e-02]]],
[[[-1.5369e-03, 6.7277e-03, -2.8153e-03, -5.2977e-03, 7.7895e-04],
[ 2.6822e-02, -2.7268e-03, 8.9981e-03, -1.0177e-02, 3.8218e-03],
[-1.2254e-02, 7.7431e-03, -3.3513e-02, -3.2448e-02, -8.5315e-03],
[-1.4555e-02, 3.1427e-02, 2.7359e-02, -4.7352e-04, -2.2193e-02],
[ 5.6234e-03, -2.1741e-03, -1.8802e-02, -1.0976e-02, 2.7378e-02]],
[[ 7.5464e-03, -2.4656e-02, -3.2512e-02, -3.2849e-02, -1.5935e-02],
[-1.8062e-02, -8.6980e-03, -1.6742e-02, -2.0394e-03, -7.8879e-03],
[-1.1177e-02, 8.7528e-03, -3.4705e-02, -6.7506e-03, -8.8169e-03],
[-1.8519e-02, 7.6015e-03, 2.9804e-02, 3.7601e-03, -3.3281e-02],
[ 7.4764e-03, 3.5919e-03, 1.6526e-02, -8.7982e-03, -9.3495e-03]],
[[-8.7813e-03, 2.0255e-02, 1.6511e-02, 9.1172e-04, -3.7212e-03],
[ 3.2860e-02, -1.9111e-02, 1.5490e-03, -5.6712e-03, -1.6889e-02],
[ 7.5624e-03, -6.9371e-03, 3.1618e-02, -6.7844e-03, -3.1054e-02],
[ 7.0345e-03, -7.7054e-03, 3.5078e-02, 6.3236e-03, 2.3317e-02],
[-2.2862e-02, -9.7549e-03, 2.7260e-02, -3.3476e-02, 1.8389e-02]],
...,
[[ 1.3349e-02, 1.8076e-02, -2.1153e-03, 6.5682e-04, 2.4534e-02],
[ 5.3663e-03, 2.8427e-02, 7.8194e-03, -1.4124e-02, 3.0364e-02],
[ 3.1933e-02, -2.4390e-02, -3.3345e-02, 1.1310e-02, -2.0207e-02],
[-2.5572e-02, 2.6358e-02, 1.7217e-02, -2.9017e-03, 7.4605e-03],
[-3.0439e-02, 1.5487e-02, -5.8104e-03, -3.2419e-02, 6.8073e-03]],
[[ 1.1699e-02, 2.2438e-02, -1.2508e-02, -1.1145e-02, 1.1388e-02],
[-1.5566e-02, 1.7208e-02, -1.0435e-02, 9.3911e-03, 2.2554e-03],
[-7.6326e-03, -1.4475e-02, -7.9627e-04, 3.4089e-02, -2.1129e-02],
[ 8.4534e-05, -1.6221e-02, -4.5830e-03, -2.2959e-02, -2.0502e-02],
[-2.4321e-02, 1.4042e-02, 3.4342e-03, 2.5126e-02, 3.1417e-02]],
[[-7.1903e-03, -1.5285e-02, 3.4991e-02, -1.1870e-02, 3.3646e-02],
[ 2.5525e-02, 2.7944e-02, 3.1858e-02, 2.1613e-02, 2.5457e-02],
[-1.5631e-02, -1.9511e-02, 1.4821e-02, 6.1392e-03, -2.2879e-02],
[-2.0709e-03, 2.4683e-02, 1.5450e-02, 1.7543e-02, 3.0431e-02],
[ 1.2472e-02, 3.3912e-02, -3.3891e-02, -2.9483e-02, 2.1657e-02]]],
[[[ 1.2854e-03, 1.4983e-02, 2.6787e-03, -3.7954e-03, -3.8526e-04],
[-9.0427e-03, -2.3686e-02, -2.5989e-02, 2.9986e-02, -3.4829e-02],
[ 1.4774e-02, -2.5571e-03, 1.0485e-02, 5.6443e-03, -4.8553e-03],
[ 3.6432e-03, 1.7875e-02, 1.5348e-03, 1.8016e-02, 8.2804e-03],
[ 1.9742e-02, -1.5757e-02, -2.3739e-02, 1.6706e-02, -5.5210e-03]],
[[-1.4544e-02, 2.8706e-02, 3.0579e-02, -3.2698e-02, 9.7423e-03],
[-2.7827e-02, -3.6608e-03, 6.3911e-03, 1.1768e-03, -2.5861e-02],
[-3.3910e-02, 3.2610e-02, -1.0725e-02, -2.5239e-02, 1.3869e-02],
[-2.4907e-02, 2.2308e-02, 3.4435e-02, -1.1574e-02, -1.5687e-02],
[-3.3932e-02, -1.4322e-02, -9.2028e-03, 1.6489e-02, 2.7247e-02]],
[[ 1.9997e-02, 2.0339e-02, -2.4083e-02, 2.1822e-02, 3.5218e-02],
[ 4.6625e-03, -4.3648e-03, 1.3782e-02, -9.2227e-03, -1.9670e-02],
[-2.3750e-02, -1.6718e-02, -8.2103e-03, -3.0051e-02, 2.0756e-02],
[ 9.1907e-03, -2.6468e-02, 2.6651e-02, 2.2466e-02, 2.2550e-02],
[-2.5045e-02, -3.3377e-02, 2.4491e-02, -2.2864e-02, -6.7297e-03]],
...,
[[ 2.5862e-02, -1.0371e-03, -1.9383e-02, 1.7942e-03, 1.9761e-02],
[-2.8891e-02, 6.3965e-03, -2.6830e-02, 1.3699e-02, 1.0821e-02],
[ 1.1547e-02, -1.9258e-02, 1.4291e-02, -7.1339e-03, -2.1092e-02],
[ 1.1358e-02, 5.8365e-03, -2.8330e-02, -1.6591e-02, 1.4738e-02],
[-3.0109e-03, -3.1205e-02, 1.0713e-02, -2.7946e-02, 5.6631e-03]],
[[-3.4269e-03, 5.9247e-03, -2.2628e-02, 1.5790e-02, 1.6851e-02],
[-1.3199e-02, 1.1820e-02, -2.9882e-02, 1.5963e-02, 1.4160e-02],
[-1.8430e-02, 1.7088e-02, 1.0258e-02, -3.1797e-02, -1.5712e-02],
[ 5.8283e-03, -3.2654e-02, -1.2848e-02, 2.9440e-02, 1.5735e-02],
[ 1.5160e-02, -1.1311e-03, 1.5635e-02, -3.1450e-03, -2.3950e-02]],
[[ 1.1689e-02, -5.3986e-03, -1.7156e-02, 1.5808e-02, 1.1226e-02],
[ 6.9512e-03, 1.9596e-02, -9.9320e-03, -2.5242e-02, 2.5922e-02],
[-1.1149e-03, -1.8153e-02, -6.7535e-03, -4.0143e-03, 1.5343e-02],
[ 2.1245e-02, 2.0272e-02, -1.5746e-02, 4.4477e-03, 1.3009e-02],
[-2.7403e-02, -1.7578e-03, -5.7534e-03, -4.3350e-03, 3.2173e-03]]],
...,
[[[ 2.4695e-02, 2.4198e-02, 5.4242e-03, 1.7946e-03, 1.5525e-03],
[-6.9484e-03, -2.8010e-02, 6.0022e-03, -3.4202e-02, -9.2220e-03],
[-1.8714e-02, -3.3158e-02, 1.2717e-02, 2.1173e-02, 2.3357e-02],
[ 2.1218e-02, 1.3226e-02, 1.5477e-02, 1.4576e-02, -1.6706e-02],
[ 5.8316e-03, 1.7646e-02, 2.6505e-02, 1.6435e-02, -9.7523e-03]],
[[-7.5557e-03, 3.0235e-02, 4.0494e-03, -6.5395e-03, -2.7983e-02],
[-1.6704e-02, -1.4708e-02, -1.2753e-02, 2.0003e-02, 2.0317e-02],
[-1.4792e-02, -2.5440e-03, 8.3960e-03, 3.1746e-02, -4.2791e-03],
[-2.6947e-02, 2.8178e-02, -8.7998e-03, -1.1918e-03, 1.9409e-02],
[ 2.7548e-02, 2.9289e-02, -1.4868e-02, -2.4845e-02, 2.1959e-02]],
[[ 2.7115e-03, 2.1398e-02, -2.7235e-02, 2.4657e-02, 3.2983e-02],
[-2.9811e-02, 2.8511e-02, 2.6691e-02, 3.0088e-02, 5.6536e-03],
[ 1.2006e-02, 9.3720e-03, -8.6544e-04, -1.8001e-02, -4.6723e-04],
[ 1.5956e-02, 1.0558e-02, -2.6408e-03, -1.8055e-02, -1.8820e-02],
[ 2.0884e-02, -6.9533e-03, -2.7761e-02, -2.1180e-02, -1.0313e-02]],
...,
[[-2.5222e-02, -2.6723e-02, -2.5127e-02, -8.3920e-03, 1.2354e-02],
[ 2.4635e-02, -2.1187e-02, 3.2576e-02, 5.1753e-03, -1.5645e-02],
[-2.7097e-02, -1.3811e-02, -2.8127e-02, -7.5398e-03, -2.5397e-02],
[-3.2788e-02, 9.4662e-03, -2.5773e-02, -5.5557e-03, -2.1646e-02],
[-3.9811e-03, -2.9400e-02, -2.9801e-02, 3.4086e-03, -2.5995e-02]],
[[-1.7667e-02, -2.9269e-02, 3.3983e-02, 1.7904e-03, -1.1844e-02],
[-3.1558e-03, 2.8698e-02, -2.4786e-02, 2.8517e-02, -2.1105e-02],
[-1.3482e-02, -2.3590e-02, 1.5106e-02, -2.6257e-02, -3.1513e-02],
[ 1.5126e-02, 9.0866e-03, -3.5108e-03, -3.1232e-02, 1.0039e-02],
[ 1.0646e-02, 2.0490e-02, -1.6026e-02, 9.4491e-03, 1.8696e-03]],
[[ 9.5597e-03, -1.0937e-03, 3.1415e-03, -2.9728e-02, 2.4290e-02],
[ 2.1983e-02, -4.4185e-03, -4.3551e-03, -2.3103e-03, 2.8911e-02],
[-2.3258e-02, -2.7318e-02, 2.5071e-02, -2.8034e-02, -9.6178e-03],
[ 1.1631e-02, -6.4006e-03, 3.1090e-02, -2.6229e-02, -3.1959e-02],
[-1.8579e-03, 1.1335e-02, -1.9144e-02, 2.1692e-02, 7.2188e-03]]],
[[[-5.6811e-03, -1.0477e-02, 1.1886e-02, 7.3932e-03, 1.6800e-02],
[ 2.9957e-02, -4.6041e-03, 1.7368e-02, 2.9004e-02, -1.8263e-02],
[-2.6259e-02, 1.8272e-02, -8.9695e-03, 2.3765e-02, -3.3679e-02],
[-1.0965e-02, -7.6722e-03, 2.3450e-02, -1.0505e-02, 5.3181e-03],
[-5.6810e-03, -2.4764e-02, 3.1046e-02, -2.9747e-03, -2.8656e-02]],
[[ 2.0621e-02, 3.0689e-02, 9.9618e-03, -1.2074e-02, -3.4941e-02],
[ 2.5171e-02, 2.5641e-02, 2.3229e-02, -2.1664e-02, -1.0035e-02],
[-2.7126e-02, 1.5039e-03, -1.8666e-02, -1.3896e-03, -2.8527e-02],
[-3.8435e-03, 1.9811e-02, 2.6598e-02, 3.9880e-03, -1.2667e-02],
[-2.9899e-02, -1.0524e-04, -6.9346e-03, -9.0742e-03, -9.5847e-03]],
[[ 3.4221e-02, 1.8155e-02, -2.0950e-02, -3.0531e-02, -1.6531e-02],
[-2.6935e-02, -2.9924e-02, -1.2559e-02, -1.1806e-02, -3.4378e-02],
[-3.1124e-03, 6.3495e-03, -2.1526e-02, -2.1942e-02, 4.1308e-03],
[-4.0192e-03, -7.1271e-03, 3.2742e-02, 9.6951e-03, -8.8074e-03],
[ 5.0674e-03, -3.0229e-02, 1.2071e-02, -1.5985e-02, -1.9603e-02]],
...,
[[-2.6260e-02, -1.7585e-02, -2.1982e-02, -2.3320e-02, -4.1119e-03],
[ 1.3096e-02, -1.7109e-02, -1.3888e-02, -2.8812e-05, 2.6391e-02],
[-1.3052e-02, -1.1130e-02, -2.9985e-02, 1.2317e-02, -3.1856e-02],
[-8.2432e-03, 2.8641e-02, 2.4846e-02, 5.5159e-03, 1.9084e-02],
[-3.4092e-02, -2.1065e-02, 2.2432e-02, -1.6194e-02, 2.2492e-02]],
[[ 3.3804e-02, 1.4972e-02, 2.2994e-03, 3.1839e-02, -1.3227e-02],
[-3.3975e-02, -3.4533e-02, -1.6026e-02, 2.2788e-03, -1.4643e-03],
[ 3.2438e-02, -2.8320e-02, -1.8481e-02, 5.8380e-03, -5.3999e-03],
[ 9.1915e-03, 3.6022e-03, 1.1685e-04, -2.4490e-02, 1.3981e-02],
[ 5.1616e-03, -6.7223e-03, 2.9258e-02, 3.0399e-02, -1.0489e-02]],
[[-2.0869e-02, -2.7418e-02, 9.2013e-03, 1.0312e-03, -1.3312e-02],
[ 7.1380e-05, 1.8098e-02, 4.8561e-03, 2.6030e-02, 4.5902e-03],
[-3.8396e-06, -9.0726e-03, -3.4657e-02, -3.5020e-02, 1.7769e-02],
[ 4.7196e-03, -3.0351e-04, 6.6702e-03, -1.5387e-02, -1.5521e-02],
[ 2.0964e-03, 1.5412e-02, 2.2774e-02, -4.0799e-03, 2.0905e-02]]],
[[[ 1.6501e-02, -2.7081e-02, 9.0558e-03, 2.5332e-03, 1.2791e-02],
[-7.0474e-03, 1.6052e-02, 2.9610e-02, -4.8062e-03, 2.2890e-02],
[-2.1236e-02, -8.2819e-03, -3.3545e-02, 3.3778e-02, 2.5133e-02],
[ 3.3057e-02, -1.5296e-02, 1.7353e-02, -4.8650e-03, 3.4039e-02],
[ 2.9383e-02, 2.2072e-02, -1.3218e-02, 2.5207e-02, -3.1896e-02]],
[[ 1.3995e-02, -2.0663e-02, -7.7605e-03, -2.9423e-02, 1.3063e-02],
[ 2.5140e-02, -9.7825e-03, -2.2534e-02, 1.2679e-02, -2.3407e-03],
[ 3.0757e-02, 1.2600e-02, -9.9360e-03, -2.9706e-02, 3.0537e-02],
[ 2.1376e-02, 2.1465e-02, 1.6579e-02, -2.8762e-02, -2.8087e-02],
[ 2.5223e-02, -1.7151e-02, 7.6622e-03, 3.3316e-02, 9.0349e-03]],
[[-3.2144e-02, -3.4602e-02, -1.2078e-03, 2.5526e-02, -2.1524e-02],
[-3.4233e-02, 8.9771e-03, 3.4649e-02, -2.9127e-02, -2.1181e-02],
[-2.3100e-02, 3.9237e-03, -2.6253e-02, -3.1718e-02, 3.5719e-03],
[-1.8098e-02, 2.5035e-02, -1.7552e-02, 2.4375e-02, -2.1021e-02],
[-3.5218e-02, -2.3443e-02, 3.5088e-02, -1.7220e-03, 2.7329e-03]],
...,
[[ 2.0920e-02, -2.8072e-02, 2.0046e-02, 2.1542e-02, 2.8214e-02],
[-1.4556e-02, -2.8212e-02, -8.9851e-03, -3.3224e-02, 2.4021e-02],
[-2.6013e-02, -3.1922e-02, -1.5481e-04, -1.2796e-02, 5.5097e-03],
[ 3.1735e-02, 2.0218e-02, 2.9766e-02, 2.7779e-02, 3.0098e-02],
[ 1.8398e-03, -2.1543e-02, -2.9273e-02, -5.8921e-03, 1.6662e-02]],
[[ 2.8385e-02, -1.8564e-02, -6.6080e-03, 2.5015e-02, -3.3667e-03],
[-1.9641e-02, -3.4182e-02, 2.0578e-02, -2.0450e-02, 3.3780e-03],
[ 6.3422e-04, 2.6127e-02, 1.1615e-02, 2.5706e-02, -2.4499e-02],
[ 1.9641e-02, 3.2078e-02, -2.9023e-02, -3.4537e-02, 1.1839e-02],
[-1.1807e-03, 3.3522e-02, -2.9450e-02, -3.3327e-02, -3.3981e-02]],
[[-2.3388e-03, -8.2398e-03, -1.6055e-02, -1.2572e-03, -1.9137e-02],
[-2.5775e-02, 1.8130e-03, -4.7393e-04, 1.8243e-02, -1.0023e-02],
[ 8.7671e-03, -2.4885e-02, -2.3222e-02, 3.0048e-02, -2.0407e-02],
[-2.8137e-02, 3.1120e-02, -7.6599e-03, -1.9271e-02, 2.2285e-02],
[-2.9455e-02, 1.2270e-02, -2.2933e-02, 6.5161e-03, 2.0707e-02]]]]), 'conv2.bias': tensor([-0.0294, 0.0329, 0.0288, 0.0177, -0.0267, 0.0292, -0.0254, 0.0350,
0.0102, 0.0154, 0.0301, 0.0296, 0.0310, -0.0298, 0.0040, 0.0274,
0.0159, 0.0237, -0.0041, 0.0274, 0.0122, -0.0251, 0.0142, 0.0245,
0.0126, 0.0059, 0.0148, 0.0251, 0.0234, 0.0009, -0.0098, -0.0203,
0.0108, -0.0250, 0.0113, 0.0071, 0.0188, 0.0034, 0.0039, -0.0132,
0.0325, 0.0291, -0.0115, 0.0208, 0.0073, -0.0019, 0.0264, -0.0297,
-0.0012, -0.0327, -0.0204, -0.0143, -0.0182, 0.0242, 0.0229, -0.0135,
-0.0017, 0.0063, 0.0077, -0.0157, -0.0293, 0.0144, 0.0262, 0.0104]), 'fc1.weight': tensor([[-0.0081, 0.0138, -0.0154, ..., 0.0130, 0.0123, 0.0151],
[-0.0053, 0.0166, -0.0099, ..., -0.0101, -0.0069, -0.0051],
[-0.0101, 0.0112, -0.0009, ..., 0.0026, -0.0106, 0.0138],
...,
[-0.0068, -0.0142, -0.0123, ..., -0.0109, -0.0112, -0.0102],
[-0.0028, 0.0046, 0.0162, ..., -0.0176, 0.0086, 0.0032],
[ 0.0050, -0.0094, -0.0085, ..., -0.0165, -0.0068, 0.0035]]), 'fc1.bias': tensor([-3.4672e-03, -1.4446e-02, -1.2293e-02, 7.7690e-03, -7.4937e-03,
-1.2601e-02, 7.4108e-03, -8.5502e-03, 1.6310e-02, 7.2417e-03,
-1.0148e-02, 1.6993e-02, -9.7666e-04, -4.3463e-03, -1.1272e-02,
-4.1479e-03, 1.7398e-02, 1.3415e-02, -1.7631e-02, 8.8416e-03,
-3.1741e-03, -1.4023e-02, -1.5655e-02, 1.0000e-02, -9.5185e-03,
-3.8707e-03, 2.1299e-03, 1.2721e-03, -1.0397e-02, -8.5392e-03,
-1.2514e-02, -2.3353e-03, 7.8897e-03, -7.6218e-03, 1.2260e-02,
-1.6806e-02, -7.9503e-03, 8.0836e-03, 1.1840e-02, 2.2876e-03,
-2.4980e-03, 3.8789e-03, -1.4930e-02, 1.4448e-02, 1.6045e-02,
-6.4406e-03, 6.9938e-03, -1.5074e-02, -3.4915e-04, 6.6718e-03,
-3.5812e-03, -1.6976e-02, 1.2715e-03, -6.0759e-03, -9.5487e-03,
1.6535e-02, 1.2655e-02, -1.3646e-02, -1.2447e-02, -6.4641e-03,
5.3294e-03, 5.6371e-03, -4.3157e-03, -3.4694e-03, -1.6611e-02,
-8.5411e-04, -8.5772e-03, -8.4273e-03, -1.5747e-02, -1.3618e-02,
1.0321e-02, -9.3956e-03, 1.0570e-02, 1.7520e-02, 6.9964e-03,
-4.6320e-03, 7.5614e-03, -6.2394e-03, -8.5712e-03, 1.5812e-02,
-1.5301e-02, 9.6769e-03, -1.3045e-02, -1.3433e-02, -9.7229e-04,
1.7275e-02, -7.5429e-03, 4.2608e-04, -1.4852e-03, -3.8250e-03,
1.3177e-02, -4.2672e-03, 2.9165e-03, 3.5423e-04, -1.6563e-02,
-1.7646e-02, -4.8865e-03, 1.1881e-02, -5.0371e-03, 9.3326e-03,
4.5758e-03, -4.5849e-03, 3.7344e-03, -1.8454e-03, -1.6846e-02,
8.2546e-03, -7.1566e-03, 1.4772e-03, -1.6290e-02, -1.0622e-02,
1.0886e-02, -5.8009e-03, 1.7793e-02, -8.6404e-03, -2.6911e-03,
1.3075e-02, -7.9632e-03, -7.3142e-03, -9.1669e-03, 7.8864e-04,
-1.3171e-02, 1.4579e-02, 1.1616e-02, 3.9550e-03, 1.1550e-02,
-1.5605e-02, -8.4229e-03, 1.3751e-02, 3.6031e-03, 1.5572e-02,
-1.7369e-03, -8.5769e-03, 5.8602e-03, -6.0227e-03, 1.6866e-02,
-1.5111e-02, 1.4616e-02, -6.7068e-03, -1.1656e-02, 4.7307e-03,
1.5767e-02, 1.1070e-02, -8.8410e-03, 1.7600e-02, -2.0084e-03,
6.8243e-03, 1.2983e-02, 5.2070e-04, 1.0046e-02, 1.3286e-02,
8.9343e-03, 4.8149e-03, -2.5697e-03, 4.4682e-03, 1.6287e-02,
6.3040e-03, -8.4443e-03, -5.3058e-04, -4.3037e-03, 1.5347e-02,
2.1996e-04, -1.0720e-02, -5.8503e-03, 1.3797e-02, -1.4177e-02,
1.1434e-02, -1.8945e-03, 1.6068e-02, 1.7447e-03, -1.3956e-02,
6.2457e-03, -1.0211e-02, 1.5087e-02, 9.6760e-03, 1.3598e-02,
1.7340e-02, -1.2309e-02, -6.7817e-03, 1.2550e-02, 1.3340e-02,
1.6681e-02, 9.6821e-03, -1.1964e-02, 1.2771e-02, -6.4358e-03,
1.2654e-02, -1.7382e-02, -3.5477e-03, -1.7585e-02, 4.1828e-03,
9.5395e-03, -1.2341e-02, -8.0899e-03, 1.7100e-02, -2.0383e-03,
7.7255e-03, 1.4668e-02, -1.1553e-02, -1.3684e-03, -1.0668e-02,
4.9831e-03, -4.1533e-03, 1.0619e-02, 4.4827e-03, -1.1317e-02,
5.3828e-03, 7.2284e-03, 1.2856e-02, -6.4634e-03, -1.1901e-02,
1.3786e-02, 9.3409e-03, -7.5928e-03, 5.4179e-04, 2.0796e-04,
1.4698e-02, 1.3254e-02, -1.1621e-02, -2.6928e-03, -9.9327e-03,
1.7629e-02, 7.1257e-03, -1.2520e-02, -4.3111e-03, 7.7188e-03,
-8.9904e-03, -5.9841e-03, -1.7572e-03, 8.8026e-03, -1.9239e-03,
1.4128e-02, 9.7155e-03, 1.4960e-02, -3.0571e-03, -1.4444e-02,
1.2553e-02, 1.1271e-02, -8.8978e-03, -1.3108e-02, 1.2628e-02,
5.6482e-03, -4.5838e-03, 9.3955e-03, -1.2634e-02, 2.5492e-03,
-1.0865e-02, -1.1644e-02, -1.2602e-02, -1.5807e-03, -8.9658e-03,
5.1082e-03, -3.3411e-04, 8.6929e-04, 8.9536e-03, 1.0715e-02,
-2.4002e-03, -1.6245e-02, 3.0127e-03, 1.2196e-02, -1.6267e-02,
1.6278e-02, -4.9497e-03, 1.5032e-02, 4.7426e-04, -6.2285e-04,
-2.2680e-03, -8.8868e-04, -1.2714e-03, 1.1415e-03, -4.5226e-03,
-1.5853e-02, -7.4868e-03, -1.0161e-02, 1.7643e-02, -1.5002e-02,
1.7216e-02, -5.0324e-03, -1.0926e-02, 9.9244e-03, 1.3024e-02,
1.0218e-02, -4.3209e-03, -4.2856e-03, 1.2696e-02, 2.9352e-03,
2.7632e-03, 9.8186e-03, 4.4106e-03, -1.7612e-02, 4.3815e-03,
8.0082e-03, 2.2632e-03, -8.4109e-03, -1.3274e-02, 1.1617e-02,
-1.7727e-02, 5.5763e-03, -1.0286e-02, 1.1968e-02, -1.5516e-02,
6.6903e-03, 6.7595e-03, -2.5033e-03, -1.1838e-02, 2.0463e-03,
1.1892e-02, -2.2310e-03, 1.5878e-02, -1.6940e-02, 6.4767e-03,
1.7238e-02, 1.5441e-02, 1.2099e-02, 1.1450e-02, 1.6676e-02,
-3.4195e-03, -1.2476e-03, -1.3253e-03, 1.3067e-02, -1.3566e-03,
6.8635e-03, -5.5605e-03, 6.0657e-03, -1.2112e-02, 9.5660e-03,
-1.3109e-03, 5.1050e-03, 1.4025e-02, -1.4562e-02, 3.0868e-03,
9.3931e-03, 1.7545e-02, -1.5243e-02, -1.2314e-02, -4.1206e-04,
1.7688e-02, -1.7570e-02, 1.7019e-02, -1.7788e-02, 1.6966e-02,
4.0122e-03, 3.0628e-03, -4.0461e-03, 1.4157e-02, 3.9072e-03,
6.3313e-03, -1.3319e-02, -1.1896e-02, 3.6852e-03, 1.1832e-02,
-8.7784e-03, 1.1129e-02, 1.3978e-02, 1.0630e-02, 3.6990e-03,
4.4645e-03, -1.6836e-02, -1.3500e-02, 1.6876e-02, 1.4516e-02,
7.4516e-03, -1.6084e-02, 1.5842e-02, 1.2070e-02, 1.5367e-02,
3.1857e-03, -1.7789e-03, -1.4422e-02, -1.4149e-02, 8.2015e-03,
8.7930e-03, -7.6206e-03, 4.6303e-03, -9.0639e-04, -7.6241e-03,
6.0896e-03, 7.6024e-03, -1.5949e-02, -1.1160e-03, -3.8057e-03,
1.1965e-03, -9.5787e-03, 1.1893e-02, 4.2951e-03, 4.4890e-03,
-1.1108e-02, -5.2652e-03, -6.9700e-04, 1.3596e-02, 5.8716e-03,
-7.8927e-03, -9.4266e-03, 3.3122e-03, 1.5078e-02, 4.2493e-03,
-1.2647e-02, 1.6407e-02, -4.4845e-03, -8.5834e-03, -9.3776e-03,
7.0308e-03, 1.5408e-02, -1.5356e-02, -6.5015e-03, 3.8347e-03,
4.1556e-03, -1.1227e-02, -1.4538e-03, 4.5388e-03, -1.5766e-02,
-5.1742e-03, 5.2968e-03, -7.3040e-03, 2.2809e-03, 2.1299e-03,
-1.6927e-02, -1.1344e-02, -1.1302e-02, -1.5232e-02, 1.3569e-02,
-3.4408e-03, -1.4370e-02, 1.5899e-02, 5.6999e-03, 2.5900e-03,
-9.2822e-03, -4.7884e-03, 1.2711e-02, 1.4953e-02, -9.6008e-03,
-7.8154e-03, 2.3049e-03, -7.2286e-03, 3.4406e-03, -1.4979e-02,
1.3255e-02, -1.5416e-02, 4.7037e-03, -6.2464e-05, 6.1763e-03,
-1.7083e-02, -4.1979e-03, 4.8330e-03, 1.6848e-02, -5.8141e-03,
1.0530e-02, 1.2660e-02, 7.7921e-03, 1.2516e-02, -8.2558e-03,
9.6033e-03, 7.4281e-04, 1.6809e-02, -1.2299e-02, 1.1192e-02,
-1.2419e-02, 1.7704e-02, -2.2003e-03, -5.5301e-03, -1.1976e-02,
-7.5681e-03, 1.7068e-02, 1.3416e-02, 1.0705e-02, -1.8727e-03,
1.1100e-02, 1.7324e-02, 4.0332e-03, 5.9611e-04, 1.3360e-02,
1.2185e-02, 7.5230e-03, 1.5142e-02, 1.4654e-02, 1.2797e-02,
1.6562e-02, 5.9504e-03, 6.2322e-03, 1.6638e-02, 3.0088e-04,
-1.5574e-02, 1.0657e-02, -5.7672e-03, 1.6510e-02, -1.1042e-02,
-1.4875e-02, 1.5410e-02, -1.6385e-02, -1.1161e-03, 7.7549e-03,
1.1378e-02, 1.0371e-02, -2.7560e-03, 1.3848e-02, -1.2284e-02,
-2.9374e-03, 1.0240e-02, -1.2988e-02, -8.2888e-03, -1.4185e-02,
-1.2491e-02, -7.7231e-03, 1.1543e-02, 1.4141e-02, 6.7815e-03,
1.4062e-02, 5.3020e-04, 5.9166e-03, 7.7286e-03, 8.2705e-03,
1.2781e-03, -1.4400e-02]), 'fc2.weight': tensor([[ 0.0173, -0.0204, -0.0246, ..., -0.0214, 0.0191, 0.0375],
[ 0.0109, -0.0418, -0.0442, ..., -0.0027, 0.0389, 0.0407],
[ 0.0440, 0.0031, -0.0332, ..., 0.0294, -0.0359, -0.0237],
...,
[-0.0346, -0.0161, 0.0228, ..., -0.0070, -0.0313, -0.0109],
[ 0.0340, -0.0133, -0.0414, ..., 0.0021, -0.0173, 0.0435],
[ 0.0281, -0.0380, 0.0440, ..., -0.0395, -0.0356, -0.0384]]), 'fc2.bias': tensor([ 0.0387, -0.0387, 0.0442, 0.0035, -0.0139, 0.0319, 0.0112, -0.0021,
0.0198, -0.0422, -0.0340, -0.0312, -0.0236, 0.0207, -0.0330, -0.0092,
0.0038, 0.0132, 0.0357, 0.0131, -0.0439, 0.0403, 0.0326, -0.0034,
0.0372, -0.0296, -0.0147, -0.0009, 0.0266, -0.0027, -0.0150, 0.0015,
0.0336, -0.0176, -0.0119, 0.0077, -0.0043, 0.0044, 0.0299, 0.0308,
-0.0223, -0.0223, 0.0140, 0.0130, -0.0155, 0.0175, -0.0068, 0.0380,
0.0010, -0.0395, 0.0146, -0.0289, 0.0194, -0.0129, 0.0116, 0.0291,
-0.0225, -0.0211, 0.0376, 0.0088, -0.0227, -0.0254])}
%% Cell type:code id: tags:
```
a = [(3, 2), (2, 5), (2, 6)]
a.sort(reverse = True)
```
%% Cell type:code id: tags:
```
```
......
import argparse
import datetime
import logging
from pathlib import Path
from shutil import copy
from localconfig import LocalConfig
from torch import multiprocessing as mp
from decentralizepy import utils
from decentralizepy.graphs.Graph import Graph
from decentralizepy.mappings.Linear import Linear
from decentralizepy.node.Node import Node
......@@ -22,23 +22,10 @@ def read_ini(file_path):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-mid", "--machine_id", type=int, default=0)
parser.add_argument("-ps", "--procs_per_machine", type=int, default=1)
parser.add_argument("-ms", "--machines", type=int, default=1)
parser.add_argument(
"-ld", "--log_dir", type=str, default="./{}".format(datetime.datetime.now())
)
parser.add_argument("-is", "--iterations", type=int, default=1)
parser.add_argument("-cf", "--config_file", type=str, default="config.ini")
parser.add_argument("-ll", "--log_level", type=str, default="INFO")
parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
parser.add_argument("-gt", "--graph_type", type=str, default="edges")
args = parser.parse_args()
args = utils.get_args()
Path(args.log_dir).mkdir(parents=True, exist_ok=True)
log_level = {
"INFO": logging.INFO,
"DEBUG": logging.DEBUG,
......@@ -52,6 +39,10 @@ if __name__ == "__main__":
for section in config:
my_config[section] = dict(config.items(section))
copy(args.config_file, args.log_dir)
copy(args.graph_file, args.log_dir)
utils.write_args(args, args.log_dir)
g = Graph()
g.read_graph_from_file(args.graph_file, args.graph_type)
n_machines = args.machines
......
......@@ -141,6 +141,9 @@ class Node:
sharing_configs = config["SHARING"]
sharing_package = importlib.import_module(sharing_configs["sharing_package"])
sharing_class = getattr(sharing_package, sharing_configs["sharing_class"])
sharing_params = utils.remove_keys(
sharing_configs, ["sharing_package", "sharing_class"]
)
self.sharing = sharing_class(
self.rank,
self.machine_id,
......@@ -149,6 +152,7 @@ class Node:
self.graph,
self.model,
self.dataset,
**sharing_params
)
self.testset = self.dataset.get_testset()
......
import json
import math
import logging
import numpy
import torch
......@@ -17,6 +17,7 @@ class PartialModel(Sharing):
self.alpha = alpha
def extract_sorted_gradients(self):
logging.info("Summing up gradients")
assert len(self.model.accumulated_gradients) > 0
gradient_sum = self.model.accumulated_gradients[0]
for i in range(1, len(self.model.accumulated_gradients)):
......@@ -24,28 +25,42 @@ class PartialModel(Sharing):
gradient_sum[key] += self.model.accumulated_gradients[i][key]
gradient_sequence = []
logging.info("Turning gradients into tuples")
for key, gradient in gradient_sum.items():
for index, val in enumerate(torch.flatten(gradient)):
gradient_sequence.append((val, key, index))
gradient_sequence.sort()
logging.info("Sorting gradient tuples")
gradient_sequence.sort() # bottleneck
return gradient_sequence
def serialized_model(self):
gradient_sequence = self.extract_sorted_gradients()
logging.info("Extracted sorted gradients")
gradient_sequence = gradient_sequence[
: math.round(len(gradient_sequence) * self.alpha)
: round(len(gradient_sequence) * self.alpha)
]
m = dict()
for _, key, index in gradient_sequence:
if key not in m:
m[key] = []
m[key].append(index, torch.flatten(self.model.state_dict()[key])[index])
m[key].append(
(
index,
torch.flatten(self.model.state_dict()[key])[index].numpy().tolist(),
)
)
logging.info("Generated dictionary to send")
for key in m:
m[key] = json.dumps(m[key])
logging.info("Converted dictionary to json")
return m
def deserialized_model(self, m):
......@@ -54,5 +69,4 @@ class PartialModel(Sharing):
for key, value in m.items():
for index, param_val in json.loads(value):
torch.flatten(state_dict[key])[index] = param_val
state_dict[key] = torch.from_numpy(numpy.array(json.loads(value)))
return state_dict
......@@ -7,7 +7,7 @@ class GradientAccumulator(Training):
def __init__(
self, model, optimizer, loss, epochs_per_round="", batch_size="", shuffle=""
):
super().__init__()
super().__init__(model, optimizer, loss, epochs_per_round, batch_size, shuffle)
def train(self, dataset):
"""
......@@ -30,9 +30,11 @@ class GradientAccumulator(Training):
epoch_loss += loss_val.item()
loss_val.backward()
self.model.accumulated_gradients.append(
grad_dict={
{
k: v.grad.clone().detach()
for k, v in zip(self.model.state_dict(), self.parameters())
for k, v in zip(
self.model.state_dict(), self.model.parameters()
)
}
)
self.optimizer.step()
......
import argparse
import datetime
import json
import os
def conditional_value(var, nul, default):
if var != nul:
return var
......@@ -7,3 +13,37 @@ def conditional_value(var, nul, default):
def remove_keys(d, keys_to_remove):
return {key: d[key] for key in d if key not in keys_to_remove}
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("-mid", "--machine_id", type=int, default=0)
parser.add_argument("-ps", "--procs_per_machine", type=int, default=1)
parser.add_argument("-ms", "--machines", type=int, default=1)
parser.add_argument(
"-ld", "--log_dir", type=str, default="./{}".format(datetime.datetime.now())
)
parser.add_argument("-is", "--iterations", type=int, default=1)
parser.add_argument("-cf", "--config_file", type=str, default="config.ini")
parser.add_argument("-ll", "--log_level", type=str, default="INFO")
parser.add_argument("-gf", "--graph_file", type=str, default="36_nodes.edges")
parser.add_argument("-gt", "--graph_type", type=str, default="edges")
args = parser.parse_args()
return args
def write_args(args, path):
data = {
"machine_id": args.machine_id,
"procs_per_machine": args.procs_per_machine,
"machines": args.machines,
"log_dir": args.log_dir,
"iterations": args.iterations,
"config_file": args.config_file,
"log_level": args.log_level,
"graph_file": args.graph_file,
"graph_type": args.graph_type,
}
with open(os.path.join(path, "args.json"), "w") as of:
json.dump(data, of)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment