Organize test folders (#856)
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
122
tests/utils/test_random_utils.py
Normal file
122
tests/utils/test_random_utils.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.utils.random_utils import (
|
||||
deserialize_numpy_rng_state,
|
||||
deserialize_python_rng_state,
|
||||
deserialize_rng_state,
|
||||
deserialize_torch_rng_state,
|
||||
get_rng_state,
|
||||
seeded_context,
|
||||
serialize_numpy_rng_state,
|
||||
serialize_python_rng_state,
|
||||
serialize_rng_state,
|
||||
serialize_torch_rng_state,
|
||||
set_rng_state,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fixed_seed():
|
||||
"""Fixture to set a consistent initial seed for each test."""
|
||||
set_seed(12345)
|
||||
yield
|
||||
|
||||
|
||||
def test_serialize_deserialize_python_rng(fixed_seed):
|
||||
# Save state after generating val1
|
||||
_ = random.random()
|
||||
st = serialize_python_rng_state()
|
||||
# Next random is val2
|
||||
val2 = random.random()
|
||||
# Restore the state, so the next random should match val2
|
||||
deserialize_python_rng_state(st)
|
||||
val3 = random.random()
|
||||
assert val2 == val3
|
||||
|
||||
|
||||
def test_serialize_deserialize_numpy_rng(fixed_seed):
|
||||
_ = np.random.rand()
|
||||
st = serialize_numpy_rng_state()
|
||||
val2 = np.random.rand()
|
||||
deserialize_numpy_rng_state(st)
|
||||
val3 = np.random.rand()
|
||||
assert val2 == val3
|
||||
|
||||
|
||||
def test_serialize_deserialize_torch_rng(fixed_seed):
|
||||
_ = torch.rand(1).item()
|
||||
st = serialize_torch_rng_state()
|
||||
val2 = torch.rand(1).item()
|
||||
deserialize_torch_rng_state(st)
|
||||
val3 = torch.rand(1).item()
|
||||
assert val2 == val3
|
||||
|
||||
|
||||
def test_serialize_deserialize_rng(fixed_seed):
|
||||
# Generate one from each library
|
||||
_ = random.random()
|
||||
_ = np.random.rand()
|
||||
_ = torch.rand(1).item()
|
||||
# Serialize
|
||||
st = serialize_rng_state()
|
||||
# Generate second set
|
||||
val_py2 = random.random()
|
||||
val_np2 = np.random.rand()
|
||||
val_th2 = torch.rand(1).item()
|
||||
# Restore, so the next draws should match val_py2, val_np2, val_th2
|
||||
deserialize_rng_state(st)
|
||||
assert random.random() == val_py2
|
||||
assert np.random.rand() == val_np2
|
||||
assert torch.rand(1).item() == val_th2
|
||||
|
||||
|
||||
def test_get_set_rng_state(fixed_seed):
|
||||
st = get_rng_state()
|
||||
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
# Change states
|
||||
random.random()
|
||||
np.random.rand()
|
||||
torch.rand(1)
|
||||
# Restore
|
||||
set_rng_state(st)
|
||||
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
assert val1 == val2
|
||||
|
||||
|
||||
def test_set_seed():
|
||||
set_seed(1337)
|
||||
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
set_seed(1337)
|
||||
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
assert val1 == val2
|
||||
|
||||
|
||||
def test_seeded_context(fixed_seed):
|
||||
val1 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
with seeded_context(1337):
|
||||
seeded_val1 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
val2 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
with seeded_context(1337):
|
||||
seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())
|
||||
|
||||
assert seeded_val1 == seeded_val2
|
||||
assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context
|
||||
assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting
|
||||
Reference in New Issue
Block a user