forked from tangger/lerobot
Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
This commit is contained in:
@@ -1,8 +1,3 @@
|
||||
import random
|
||||
from typing import Callable
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
@@ -10,50 +5,6 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_
|
||||
from lerobot.common.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.common.utils.utils import (
|
||||
get_global_random_state,
|
||||
seeded_context,
|
||||
set_global_random_state,
|
||||
set_global_seed,
|
||||
)
|
||||
|
||||
# Random generation functions for testing the seeding and random state get/set.
|
||||
rand_fns = [
|
||||
random.random,
|
||||
np.random.random,
|
||||
lambda: torch.rand(1).item(),
|
||||
]
|
||||
if torch.cuda.is_available():
|
||||
rand_fns.append(lambda: torch.rand(1, device="cuda"))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("rand_fn", rand_fns)
|
||||
def test_seeding(rand_fn: Callable[[], int]):
|
||||
set_global_seed(0)
|
||||
a = rand_fn()
|
||||
with seeded_context(1337):
|
||||
c = rand_fn()
|
||||
b = rand_fn()
|
||||
set_global_seed(0)
|
||||
a_ = rand_fn()
|
||||
b_ = rand_fn()
|
||||
# Check that `set_global_seed` lets us reproduce a and b.
|
||||
assert a_ == a
|
||||
# Additionally, check that the `seeded_context` didn't interrupt the global RNG.
|
||||
assert b_ == b
|
||||
set_global_seed(1337)
|
||||
c_ = rand_fn()
|
||||
# Check that `seeded_context` and `global_seed` give the same reproducibility.
|
||||
assert c_ == c
|
||||
|
||||
|
||||
def test_get_set_random_state():
|
||||
"""Check that getting the random state, then setting it results in the same random number generation."""
|
||||
random_state_dict = get_global_random_state()
|
||||
rand_numbers = [rand_fn() for rand_fn in rand_fns]
|
||||
set_global_random_state(random_state_dict)
|
||||
rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
|
||||
assert rand_numbers_ == rand_numbers
|
||||
|
||||
|
||||
def test_calculate_episode_data_index():
|
||||
|
||||
Reference in New Issue
Block a user