Add context manager for seeding (#164)
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import logging
|
||||
import os.path as osp
|
||||
import random
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
@@ -39,6 +41,31 @@ def set_global_seed(seed):
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def seeded_context(seed: int) -> Generator[None, None, None]:
|
||||
"""Set the seed when entering a context, and restore the prior random state at exit.
|
||||
|
||||
Example usage:
|
||||
|
||||
```
|
||||
a = random.random() # produces some random number
|
||||
with seeded_context(1337):
|
||||
b = random.random() # produces some other random number
|
||||
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
|
||||
```
|
||||
"""
|
||||
random_state = random.getstate()
|
||||
np_random_state = np.random.get_state()
|
||||
torch_random_state = torch.random.get_rng_state()
|
||||
torch_cuda_random_state = torch.cuda.random.get_rng_state()
|
||||
set_global_seed(seed)
|
||||
yield None
|
||||
random.setstate(random_state)
|
||||
np.random.set_state(np_random_state)
|
||||
torch.random.set_rng_state(torch_random_state)
|
||||
torch.cuda.random.set_rng_state(torch_cuda_random_state)
|
||||
|
||||
|
||||
def init_logging():
|
||||
def custom_format(record):
|
||||
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
Reference in New Issue
Block a user