Remove offline training, refactor train.py and logging/checkpointing (#670)

Author: Simon Alibert
Co-authored-by: Remi <remi.cadene@huggingface.co>
Date: 2025-02-11 10:36:06 +01:00
Committed by: GitHub
parent 334deb985d
commit 90e099b39f
40 changed files with 1515 additions and 935 deletions

@@ -1,8 +1,3 @@
-import random
-from typing import Callable
-
-import numpy as np
 import pytest
 import torch
 from datasets import Dataset
@@ -10,50 +5,6 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
-from lerobot.common.utils.utils import (
-    get_global_random_state,
-    seeded_context,
-    set_global_random_state,
-    set_global_seed,
-)
-
-# Random generation functions for testing the seeding and random state get/set.
-rand_fns = [
-    random.random,
-    np.random.random,
-    lambda: torch.rand(1).item(),
-]
-
-if torch.cuda.is_available():
-    rand_fns.append(lambda: torch.rand(1, device="cuda"))
-
-
-@pytest.mark.parametrize("rand_fn", rand_fns)
-def test_seeding(rand_fn: Callable[[], int]):
-    set_global_seed(0)
-    a = rand_fn()
-    with seeded_context(1337):
-        c = rand_fn()
-    b = rand_fn()
-    set_global_seed(0)
-    a_ = rand_fn()
-    b_ = rand_fn()
-    # Check that `set_global_seed` lets us reproduce a and b.
-    assert a_ == a
-    # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
-    assert b_ == b
-    set_global_seed(1337)
-    c_ = rand_fn()
-    # Check that `seeded_context` and `set_global_seed` give the same reproducibility.
-    assert c_ == c
-
-
-def test_get_set_random_state():
-    """Check that getting the random state, then setting it results in the same random number generation."""
-    random_state_dict = get_global_random_state()
-    rand_numbers = [rand_fn() for rand_fn in rand_fns]
-    set_global_random_state(random_state_dict)
-    rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
-    assert rand_numbers_ == rand_numbers
 
 
 def test_calculate_episode_data_index():
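
The removed tests exercise seeding helpers from lerobot.common.utils.utils. As a reference for what those tests require, here is a minimal sketch of such helpers, assuming the only global RNGs involved are the random, numpy, and torch generators covered by rand_fns; the actual lerobot implementation may differ in naming and detail:

import random

import numpy as np
import torch


def set_global_seed(seed: int) -> None:
    # Seed every global RNG exercised by rand_fns.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def get_global_random_state() -> dict:
    # Snapshot the state of each global RNG so it can be restored later.
    state = {
        "random": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.random.get_rng_state(),
    }
    if torch.cuda.is_available():
        state["torch_cuda"] = torch.cuda.random.get_rng_state_all()
    return state


def set_global_random_state(state: dict) -> None:
    # Restore a snapshot produced by get_global_random_state; afterwards,
    # random number generation resumes exactly where the snapshot was taken,
    # which is what test_get_set_random_state asserts.
    random.setstate(state["random"])
    np.random.set_state(state["numpy"])
    torch.random.set_rng_state(state["torch"])
    if torch.cuda.is_available():
        torch.cuda.random.set_rng_state_all(state["torch_cuda"])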
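
test_seeding additionally checks that seeded_context does not interrupt the surrounding RNG stream (the `b_ == b` assertion), which implies it must snapshot the global state on entry and restore it on exit. A sketch built on the helpers above:

from contextlib import contextmanager


@contextmanager
def seeded_context(seed: int):
    # Reproducible randomness inside the `with` block.
    state = get_global_random_state()
    set_global_seed(seed)
    try:
        yield
    finally:
        # Restore the pre-context state so the outer RNG stream resumes
        # as if the seeded block never ran (test_seeding's `b_ == b`).
        set_global_random_state(state)

Usage then matches the removed test: draws made inside `with seeded_context(1337):` are reproducible from seed 1337, while the next draw outside the block continues the outer seed-0 stream unchanged.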