Skip to content
Snippets Groups Projects

Choco compression fix

Open Milos Vujasinovic requested to merge mvujas/decentralizepy:choco-compression-fix into main
import random
import contextlib
import torch
import numpy as np
@@ -14,13 +16,16 @@ def temp_seed(seed):
on CPU regardless if CUDA is used for other things.
"""
random_state = random.getstate()
np_old_state = np.random.get_state()
torch_old_state = torch.random.get_rng_state()
torch.random.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.random.manual_seed(seed)
try:
yield
finally:
random.setstate(random_state)
np.random.set_state(np_old_state)
torch.random.set_rng_state(torch_old_state)
@@ -33,8 +38,17 @@ class RandomState:
"""
def __init__(self, seed):
with temp_seed(seed):
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
self.__refresh_states()
def __refresh_states(self):
self.__random_state = random.getstate()
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
def __set_states(self):
random.setstate(self.__random_state)
np.random.set_state(self.__np_state)
torch.random.set_rng_state(self.__torch_state)
@contextlib.contextmanager
def activate(self):
@@ -44,14 +58,14 @@ class RandomState:
is finished
"""
random_state = random.getstate()
np_old_state = np.random.get_state()
torch_old_state = torch.random.get_rng_state()
np.random.set_state(self.__np_state)
torch.random.set_rng_state(self.__torch_state)
self.__set_states()
try:
yield
finally:
self.__np_state = np.random.get_state()
self.__torch_state = torch.random.get_rng_state()
self.__refresh_states()
random.setstate(random_state)
np.random.set_state(np_old_state)
torch.random.set_rng_state(torch_old_state)
\ No newline at end of file
Loading