Add context manager for seeding (#164)

This commit is contained in:
Alexander Soare
2024-05-09 17:58:39 +01:00
committed by GitHub
parent 473345fdf6
commit b187942db4
2 changed files with 65 additions and 0 deletions

View File

@@ -1,8 +1,10 @@
import logging
import os.path as osp
import random
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Generator
import hydra
import numpy as np
@@ -39,6 +41,31 @@ def set_global_seed(seed):
torch.cuda.manual_seed_all(seed)
@contextmanager
def seeded_context(seed: int) -> Generator[None, None, None]:
"""Set the seed when entering a context, and restore the prior random state at exit.
Example usage:
```
a = random.random() # produces some random number
with seeded_context(1337):
b = random.random() # produces some other random number
c = random.random() # produces yet another random number, but the same it would have if we never made `b`
```
"""
random_state = random.getstate()
np_random_state = np.random.get_state()
torch_random_state = torch.random.get_rng_state()
torch_cuda_random_state = torch.cuda.random.get_rng_state()
set_global_seed(seed)
yield None
random.setstate(random_state)
np.random.set_state(np_random_state)
torch.random.set_rng_state(torch_random_state)
torch.cuda.random.set_rng_state(torch_cuda_random_state)
def init_logging():
def custom_format(record):
dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S")