forked from tangger/lerobot
@@ -11,22 +11,24 @@ from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
reset_episode_index,
|
||||
)
|
||||
from lerobot.common.utils.utils import seeded_context, set_global_seed
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"rand_fn",
|
||||
(
|
||||
[
|
||||
random.random,
|
||||
np.random.random,
|
||||
lambda: torch.rand(1).item(),
|
||||
]
|
||||
+ [lambda: torch.rand(1, device="cuda")]
|
||||
if torch.cuda.is_available()
|
||||
else []
|
||||
),
|
||||
from lerobot.common.utils.utils import (
|
||||
get_global_random_state,
|
||||
seeded_context,
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
# Random generation functions for testing the seeding and random state get/set.
|
||||
rand_fns = [
|
||||
random.random,
|
||||
np.random.random,
|
||||
lambda: torch.rand(1).item(),
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
rand_fns.append(lambda: torch.rand(1, device="cuda"))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rand_fn", rand_fns)
|
||||
def test_seeding(rand_fn: Callable[[], int]):
|
||||
set_global_seed(0)
|
||||
a = rand_fn()
|
||||
@@ -46,6 +48,15 @@ def test_seeding(rand_fn: Callable[[], int]):
|
||||
assert c_ == c
|
||||
|
||||
|
||||
def test_get_set_random_state():
|
||||
"""Check that getting the random state, then setting it results in the same random number generation."""
|
||||
random_state_dict = get_global_random_state()
|
||||
rand_numbers = [rand_fn() for rand_fn in rand_fns]
|
||||
set_global_random_state(random_state_dict)
|
||||
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
|
||||
assert rand_numbers_ == rand_numbers
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
dataset = Dataset.from_dict(
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user