Add resume training (#205)

Co-authored-by: Remi <re.cadene@gmail.com>
This commit is contained in:
Alexander Soare
2024-05-28 12:04:23 +01:00
committed by GitHub
parent 7ec76ee235
commit e3b9f1c19b
15 changed files with 486 additions and 191 deletions

View File

@@ -11,22 +11,24 @@ from lerobot.common.datasets.utils import (
hf_transform_to_torch,
reset_episode_index,
)
from lerobot.common.utils.utils import seeded_context, set_global_seed
@pytest.mark.parametrize(
"rand_fn",
(
[
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
+ [lambda: torch.rand(1, device="cuda")]
if torch.cuda.is_available()
else []
),
from lerobot.common.utils.utils import (
get_global_random_state,
seeded_context,
set_global_random_state,
set_global_seed,
)
# Random generation functions for testing the seeding and random state get/set.
rand_fns = [
random.random,
np.random.random,
lambda: torch.rand(1).item(),
]
if torch.cuda.is_available():
rand_fns.append(lambda: torch.rand(1, device="cuda"))
@pytest.mark.parametrize("rand_fn", rand_fns)
def test_seeding(rand_fn: Callable[[], int]):
set_global_seed(0)
a = rand_fn()
@@ -46,6 +48,15 @@ def test_seeding(rand_fn: Callable[[], int]):
assert c_ == c
def test_get_set_random_state():
"""Check that getting the random state, then setting it results in the same random number generation."""
random_state_dict = get_global_random_state()
rand_numbers = [rand_fn() for rand_fn in rand_fns]
set_global_random_state(random_state_dict)
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
assert rand_numbers_ == rand_numbers
def test_calculate_episode_data_index():
dataset = Dataset.from_dict(
{