Remove offline training, refactor train.py and logging/checkpointing (#670)
Co-authored-by: Remi <remi.cadene@huggingface.co>
@@ -28,6 +28,7 @@ pytest_plugins = [
    "tests.fixtures.dataset_factories",
    "tests.fixtures.files",
    "tests.fixtures.hub",
    "tests.fixtures.optimizers",
]
26 tests/fixtures/optimizers.py vendored Normal file
@@ -0,0 +1,26 @@
import pytest
import torch

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig


@pytest.fixture
def model_params():
    return [torch.nn.Parameter(torch.randn(10, 10))]


@pytest.fixture
def optimizer(model_params):
    optimizer = AdamConfig().build(model_params)
    # Dummy step to populate state
    loss = sum(param.sum() for param in model_params)
    loss.backward()
    optimizer.step()
    return optimizer


@pytest.fixture
def scheduler(optimizer):
    config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
    return config.build(optimizer, num_training_steps=100)
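A minimal sketch (not part of the commit) of how a test module picks up these shared fixtures once "tests.fixtures.optimizers" is listed in pytest_plugins above; the test names and assertions are illustrative only.

def test_optimizer_has_populated_state(optimizer):
    # The fixture already ran one dummy backward/step, so per-parameter Adam state exists.
    assert len(optimizer.state) > 0


def test_scheduler_follows_optimizer_lr(optimizer, scheduler):
    scheduler.step()
    assert scheduler.get_last_lr() == [group["lr"] for group in optimizer.param_groups]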
@@ -25,7 +25,7 @@ from lerobot.common.datasets.transforms import (
    ImageTransformsConfig,
    make_transform_from_config,
)
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context

ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
@@ -22,14 +22,14 @@ from safetensors.torch import save_file
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.optim.factory import make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy, make_policy_config
from lerobot.common.utils.utils import set_global_seed
from lerobot.common.utils.random_utils import set_seed
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig


def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwargs):
    # TODO(rcadene, aliberts): env_name?
    set_global_seed(1337)
    set_seed(1337)

    train_cfg = TrainPipelineConfig(
        # TODO(rcadene, aliberts): remove dataset download
@@ -53,9 +53,9 @@ def get_policy_stats(ds_repo_id, env_name, policy_name, policy_kwargs, train_kwa
    )

    batch = next(iter(dataloader))
    output_dict = policy.forward(batch)
    loss, output_dict = policy.forward(batch)
    output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
    loss = output_dict["loss"]
    output_dict["loss"] = loss

    loss.backward()
    grad_stats = {}
@@ -29,7 +29,6 @@ from unittest.mock import patch
import pytest

from lerobot.common.logger import Logger
from lerobot.common.policies.act.configuration_act import ACTConfig
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
@@ -38,9 +37,7 @@ from lerobot.common.robot_devices.control_configs import (
    ReplayControlConfig,
    TeleoperateControlConfig,
)
from lerobot.configs.default import DatasetConfig
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.control_robot import calibrate, record, replay, teleoperate
from tests.test_robots import make_robot
from tests.utils import DEVICE, TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@@ -185,20 +182,8 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
    out_dir = tmpdir / "logger"

    ds_cfg = DatasetConfig(repo_id, local_files_only=True)
    train_cfg = TrainPipelineConfig(
        dataset=ds_cfg,
        policy=policy_cfg,
        output_dir=out_dir,
        device=DEVICE,
    )
    logger = Logger(train_cfg)
    logger.save_checkpoint(
        train_step=0,
        identifier=0,
        policy=policy,
    )
    pretrained_policy_path = out_dir / "checkpoints/last/pretrained_model"
    policy.save_pretrained(pretrained_policy_path)

    # In `examples/9_use_aloha.md`, we advise using `num_image_writer_processes=1`
    # during inference, to reach consistent fps, so we test this here.
@@ -46,7 +46,7 @@ from lerobot.common.datasets.utils import (
from lerobot.common.envs.factory import make_env_config
from lerobot.common.policies.factory import make_policy_config
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from tests.fixtures.constants import DUMMY_REPO_ID
@@ -28,7 +28,7 @@ from lerobot.common.datasets.transforms import (
    SharpnessJitter,
    make_transform_from_config,
)
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context
from lerobot.scripts.visualize_image_transforms import (
    save_all_transforms,
    save_each_transform,
74 tests/test_io_utils.py Normal file
@@ -0,0 +1,74 @@
import json
from pathlib import Path
from typing import Any

import pytest

from lerobot.common.utils.io_utils import deserialize_json_into_object


@pytest.fixture
def tmp_json_file(tmp_path: Path):
    """Writes `data` to a temporary JSON file and returns the file's path."""

    def _write(data: Any) -> Path:
        file_path = tmp_path / "data.json"
        with file_path.open("w", encoding="utf-8") as f:
            json.dump(data, f)
        return file_path

    return _write


def test_simple_dict(tmp_json_file):
    data = {"name": "Alice", "age": 30}
    json_path = tmp_json_file(data)
    obj = {"name": "", "age": 0}
    assert deserialize_json_into_object(json_path, obj) == data


def test_nested_structure(tmp_json_file):
    data = {"items": [1, 2, 3], "info": {"active": True}}
    json_path = tmp_json_file(data)
    obj = {"items": [0, 0, 0], "info": {"active": False}}
    assert deserialize_json_into_object(json_path, obj) == data


def test_tuple_conversion(tmp_json_file):
    data = {"coords": [10.5, 20.5]}
    json_path = tmp_json_file(data)
    obj = {"coords": (0.0, 0.0)}
    result = deserialize_json_into_object(json_path, obj)
    assert result["coords"] == (10.5, 20.5)


def test_type_mismatch_raises(tmp_json_file):
    data = {"numbers": {"bad": "structure"}}
    json_path = tmp_json_file(data)
    obj = {"numbers": [0, 0]}
    with pytest.raises(TypeError):
        deserialize_json_into_object(json_path, obj)


def test_missing_key_raises(tmp_json_file):
    data = {"one": 1}
    json_path = tmp_json_file(data)
    obj = {"one": 0, "two": 0}
    with pytest.raises(ValueError):
        deserialize_json_into_object(json_path, obj)


def test_extra_key_raises(tmp_json_file):
    data = {"one": 1, "two": 2}
    json_path = tmp_json_file(data)
    obj = {"one": 0}
    with pytest.raises(ValueError):
        deserialize_json_into_object(json_path, obj)


def test_list_length_mismatch_raises(tmp_json_file):
    data = {"nums": [1, 2, 3]}
    json_path = tmp_json_file(data)
    obj = {"nums": [0, 0]}
    with pytest.raises(ValueError):
        deserialize_json_into_object(json_path, obj)
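As an aside (not part of the diff), the pattern these tests pin down is: the second argument supplies the expected structure and element types, the JSON file supplies the values, and lists are converted back to tuples wherever the template uses a tuple. A minimal sketch with made-up keys and a temporary path:

import json
import tempfile
from pathlib import Path

from lerobot.common.utils.io_utils import deserialize_json_into_object

payload = {"repo_id": "lerobot/example", "fps": 30, "image_size": [480, 640]}  # illustrative values
template = {"repo_id": "", "fps": 0, "image_size": (0, 0)}  # expected structure and types

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "info.json"
    path.write_text(json.dumps(payload), encoding="utf-8")
    restored = deserialize_json_into_object(path, template)

assert restored["image_size"] == (480, 640)  # the list matching a tuple template comes back as a tuple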
107 tests/test_logging_utils.py Normal file
@@ -0,0 +1,107 @@
import pytest

from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker


@pytest.fixture
def mock_metrics():
    return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}


def test_average_meter_initialization():
    meter = AverageMeter("loss", ":.2f")
    assert meter.name == "loss"
    assert meter.fmt == ":.2f"
    assert meter.val == 0.0
    assert meter.avg == 0.0
    assert meter.sum == 0.0
    assert meter.count == 0.0


def test_average_meter_update():
    meter = AverageMeter("accuracy")
    meter.update(5, n=2)
    assert meter.val == 5
    assert meter.sum == 10
    assert meter.count == 2
    assert meter.avg == 5


def test_average_meter_reset():
    meter = AverageMeter("loss")
    meter.update(3, 4)
    meter.reset()
    assert meter.val == 0.0
    assert meter.avg == 0.0
    assert meter.sum == 0.0
    assert meter.count == 0.0


def test_average_meter_str():
    meter = AverageMeter("metric", ":.1f")
    meter.update(4.567, 3)
    assert str(meter) == "metric:4.6"


def test_metrics_tracker_initialization(mock_metrics):
    tracker = MetricsTracker(
        batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
    )
    assert tracker.steps == 10
    assert tracker.samples == 10 * 32
    assert tracker.episodes == tracker.samples / (1000 / 50)
    assert tracker.epochs == tracker.samples / 1000
    assert "loss" in tracker.metrics
    assert "accuracy" in tracker.metrics


def test_metrics_tracker_step(mock_metrics):
    tracker = MetricsTracker(
        batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
    )
    tracker.step()
    assert tracker.steps == 6
    assert tracker.samples == 6 * 32
    assert tracker.episodes == tracker.samples / (1000 / 50)
    assert tracker.epochs == tracker.samples / 1000


def test_metrics_tracker_getattr(mock_metrics):
    tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
    assert tracker.loss == mock_metrics["loss"]
    assert tracker.accuracy == mock_metrics["accuracy"]
    with pytest.raises(AttributeError):
        _ = tracker.non_existent_metric


def test_metrics_tracker_setattr(mock_metrics):
    tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
    tracker.loss = 2.0
    assert tracker.loss.val == 2.0


def test_metrics_tracker_str(mock_metrics):
    tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
    tracker.loss.update(3.456, 1)
    tracker.accuracy.update(0.876, 1)
    output = str(tracker)
    assert "loss:3.456" in output
    assert "accuracy:0.88" in output


def test_metrics_tracker_to_dict(mock_metrics):
    tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
    tracker.loss.update(5, 2)
    metrics_dict = tracker.to_dict()
    assert isinstance(metrics_dict, dict)
    assert metrics_dict["loss"] == 5  # average value
    assert metrics_dict["steps"] == tracker.steps


def test_metrics_tracker_reset_averages(mock_metrics):
    tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
    tracker.loss.update(10, 3)
    tracker.accuracy.update(0.95, 5)
    tracker.reset_averages()
    assert tracker.loss.avg == 0.0
    assert tracker.accuracy.avg == 0.0
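A rough sketch (not from the diff) of how these helpers might sit in a training loop, using only the calls exercised above; the loop body, batch size, and logging frequency are placeholders.

from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker

metrics = {"loss": AverageMeter("loss", ":.3f")}
tracker = MetricsTracker(batch_size=8, num_frames=1000, num_episodes=50, metrics=metrics, initial_step=0)

log_freq = 100  # placeholder
for _ in range(1000):  # stand-in for the real dataloader/update loop
    loss = 0.0  # stand-in for the value produced by the policy update
    tracker.loss = loss  # updates the underlying AverageMeter
    tracker.step()  # advances the steps/samples/episodes/epochs counters
    if tracker.steps % log_freq == 0:
        print(tracker)  # formatted meters, e.g. "loss:0.000"
        tracker.reset_averages()  # start a fresh averaging window for the next interval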
43 tests/test_optimizers.py Normal file
@@ -0,0 +1,43 @@
import pytest
import torch

from lerobot.common.constants import (
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
)
from lerobot.common.optim.optimizers import (
    AdamConfig,
    AdamWConfig,
    SGDConfig,
    load_optimizer_state,
    save_optimizer_state,
)


@pytest.mark.parametrize(
    "config_cls, expected_class",
    [
        (AdamConfig, torch.optim.Adam),
        (AdamWConfig, torch.optim.AdamW),
        (SGDConfig, torch.optim.SGD),
    ],
)
def test_optimizer_build(config_cls, expected_class, model_params):
    config = config_cls()
    optimizer = config.build(model_params)
    assert isinstance(optimizer, expected_class)
    assert optimizer.defaults["lr"] == config.lr


def test_save_optimizer_state(optimizer, tmp_path):
    save_optimizer_state(optimizer, tmp_path)
    assert (tmp_path / OPTIMIZER_STATE).is_file()
    assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()


def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
    save_optimizer_state(optimizer, tmp_path)
    loaded_optimizer = AdamConfig().build(model_params)
    loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path)

    torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
@@ -36,7 +36,7 @@ from lerobot.common.policies.factory import (
)
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.utils.utils import seeded_context
from lerobot.common.utils.random_utils import seeded_context
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
109 tests/test_random_utils.py Normal file
@@ -0,0 +1,109 @@
import random

import numpy as np
import pytest
import torch

from lerobot.common.utils.random_utils import (
    deserialize_numpy_rng_state,
    deserialize_python_rng_state,
    deserialize_rng_state,
    deserialize_torch_rng_state,
    get_rng_state,
    seeded_context,
    serialize_numpy_rng_state,
    serialize_python_rng_state,
    serialize_rng_state,
    serialize_torch_rng_state,
    set_rng_state,
    set_seed,
)


@pytest.fixture
def fixed_seed():
    """Fixture to set a consistent initial seed for each test."""
    set_seed(12345)
    yield


def test_serialize_deserialize_python_rng(fixed_seed):
    # Save state after generating val1
    _ = random.random()
    st = serialize_python_rng_state()
    # Next random is val2
    val2 = random.random()
    # Restore the state, so the next random should match val2
    deserialize_python_rng_state(st)
    val3 = random.random()
    assert val2 == val3


def test_serialize_deserialize_numpy_rng(fixed_seed):
    _ = np.random.rand()
    st = serialize_numpy_rng_state()
    val2 = np.random.rand()
    deserialize_numpy_rng_state(st)
    val3 = np.random.rand()
    assert val2 == val3


def test_serialize_deserialize_torch_rng(fixed_seed):
    _ = torch.rand(1).item()
    st = serialize_torch_rng_state()
    val2 = torch.rand(1).item()
    deserialize_torch_rng_state(st)
    val3 = torch.rand(1).item()
    assert val2 == val3


def test_serialize_deserialize_rng(fixed_seed):
    # Generate one from each library
    _ = random.random()
    _ = np.random.rand()
    _ = torch.rand(1).item()
    # Serialize
    st = serialize_rng_state()
    # Generate second set
    val_py2 = random.random()
    val_np2 = np.random.rand()
    val_th2 = torch.rand(1).item()
    # Restore, so the next draws should match val_py2, val_np2, val_th2
    deserialize_rng_state(st)
    assert random.random() == val_py2
    assert np.random.rand() == val_np2
    assert torch.rand(1).item() == val_th2


def test_get_set_rng_state(fixed_seed):
    st = get_rng_state()
    val1 = (random.random(), np.random.rand(), torch.rand(1).item())
    # Change states
    random.random()
    np.random.rand()
    torch.rand(1)
    # Restore
    set_rng_state(st)
    val2 = (random.random(), np.random.rand(), torch.rand(1).item())
    assert val1 == val2


def test_set_seed():
    set_seed(1337)
    val1 = (random.random(), np.random.rand(), torch.rand(1).item())
    set_seed(1337)
    val2 = (random.random(), np.random.rand(), torch.rand(1).item())
    assert val1 == val2


def test_seeded_context(fixed_seed):
    val1 = (random.random(), np.random.rand(), torch.rand(1).item())
    with seeded_context(1337):
        seeded_val1 = (random.random(), np.random.rand(), torch.rand(1).item())
    val2 = (random.random(), np.random.rand(), torch.rand(1).item())
    with seeded_context(1337):
        seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())

    assert seeded_val1 == seeded_val2
    assert all(a != b for a, b in zip(val1, seeded_val1, strict=True))  # changed inside the context
    assert all(a != b for a, b in zip(val2, seeded_val2, strict=True))  # changed again after exiting
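Outside the tests, these utilities cover two recurring needs in this refactor; a short sketch with illustrative seeds and values: a reproducible block that leaves the surrounding RNG streams untouched, and a full RNG snapshot of the kind a checkpoint can carry.

import torch

from lerobot.common.utils.random_utils import get_rng_state, seeded_context, set_rng_state, set_seed

set_seed(1000)  # illustrative run seed

with seeded_context(1337):
    eval_noise = torch.rand(3)  # reproducible across runs; the outer stream resumes unchanged afterwards

snapshot = get_rng_state()  # the kind of state a checkpoint's training state can include
a = torch.rand(1)
set_rng_state(snapshot)  # restoring replays the same draws
b = torch.rand(1)
assert torch.equal(a, b)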
81 tests/test_schedulers.py Normal file
@@ -0,0 +1,81 @@
from torch.optim.lr_scheduler import LambdaLR

from lerobot.common.constants import SCHEDULER_STATE
from lerobot.common.optim.schedulers import (
    CosineDecayWithWarmupSchedulerConfig,
    DiffuserSchedulerConfig,
    VQBeTSchedulerConfig,
    load_scheduler_state,
    save_scheduler_state,
)


def test_diffuser_scheduler(optimizer):
    config = DiffuserSchedulerConfig(name="cosine", num_warmup_steps=5)
    scheduler = config.build(optimizer, num_training_steps=100)
    assert isinstance(scheduler, LambdaLR)

    optimizer.step()  # so that we don't get torch warning
    scheduler.step()
    expected_state_dict = {
        "_get_lr_called_within_step": False,
        "_last_lr": [0.0002],
        "_step_count": 2,
        "base_lrs": [0.001],
        "last_epoch": 1,
        "lr_lambdas": [None],
        "verbose": False,
    }
    assert scheduler.state_dict() == expected_state_dict


def test_vqbet_scheduler(optimizer):
    config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
    scheduler = config.build(optimizer, num_training_steps=100)
    assert isinstance(scheduler, LambdaLR)

    optimizer.step()
    scheduler.step()
    expected_state_dict = {
        "_get_lr_called_within_step": False,
        "_last_lr": [0.001],
        "_step_count": 2,
        "base_lrs": [0.001],
        "last_epoch": 1,
        "lr_lambdas": [None],
        "verbose": False,
    }
    assert scheduler.state_dict() == expected_state_dict


def test_cosine_decay_with_warmup_scheduler(optimizer):
    config = CosineDecayWithWarmupSchedulerConfig(
        num_warmup_steps=10, num_decay_steps=90, peak_lr=0.01, decay_lr=0.001
    )
    scheduler = config.build(optimizer, num_training_steps=100)
    assert isinstance(scheduler, LambdaLR)

    optimizer.step()
    scheduler.step()
    expected_state_dict = {
        "_get_lr_called_within_step": False,
        "_last_lr": [0.0001818181818181819],
        "_step_count": 2,
        "base_lrs": [0.001],
        "last_epoch": 1,
        "lr_lambdas": [None],
        "verbose": False,
    }
    assert scheduler.state_dict() == expected_state_dict


def test_save_scheduler_state(scheduler, tmp_path):
    save_scheduler_state(scheduler, tmp_path)
    assert (tmp_path / SCHEDULER_STATE).is_file()


def test_save_load_scheduler_state(scheduler, tmp_path):
    save_scheduler_state(scheduler, tmp_path)
    loaded_scheduler = load_scheduler_state(scheduler, tmp_path)

    assert scheduler.state_dict() == loaded_scheduler.state_dict()
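A small sketch (not in the diff) of the resume path the last two tests point at, assuming that restoring into a freshly built scheduler reproduces the saved counters; the directory path and the number of warm-up steps are illustrative.

from pathlib import Path

import torch

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig, load_scheduler_state, save_scheduler_state

# Build an optimizer/scheduler pair the same way the fixtures above do.
params = [torch.nn.Parameter(torch.randn(10, 10))]
optimizer = AdamConfig().build(params)
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
scheduler = config.build(optimizer, num_training_steps=100)

# Simulate a few training steps, then persist the scheduler state.
for _ in range(3):
    optimizer.step()
    scheduler.step()

state_dir = Path("outputs/example/training_state")  # illustrative location
state_dir.mkdir(parents=True, exist_ok=True)
save_scheduler_state(scheduler, state_dir)

# On resume, rebuild from the same config and restore the saved counters.
resumed = load_scheduler_state(config.build(optimizer, num_training_steps=100), state_dir)
assert resumed.state_dict() == scheduler.state_dict()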
84 tests/test_train_utils.py Normal file
@@ -0,0 +1,84 @@
from pathlib import Path
from unittest.mock import Mock, patch

from lerobot.common.constants import (
    CHECKPOINTS_DIR,
    LAST_CHECKPOINT_LINK,
    OPTIMIZER_PARAM_GROUPS,
    OPTIMIZER_STATE,
    RNG_STATE,
    SCHEDULER_STATE,
    TRAINING_STATE_DIR,
    TRAINING_STEP,
)
from lerobot.common.utils.train_utils import (
    get_step_checkpoint_dir,
    get_step_identifier,
    load_training_state,
    load_training_step,
    save_checkpoint,
    save_training_state,
    save_training_step,
    update_last_checkpoint,
)


def test_get_step_identifier():
    assert get_step_identifier(5, 1000) == "000005"
    assert get_step_identifier(123, 100_000) == "000123"
    assert get_step_identifier(456789, 1_000_000) == "0456789"


def test_get_step_checkpoint_dir():
    output_dir = Path("/checkpoints")
    step_dir = get_step_checkpoint_dir(output_dir, 1000, 5)
    assert step_dir == output_dir / CHECKPOINTS_DIR / "000005"


def test_save_load_training_step(tmp_path):
    save_training_step(5000, tmp_path)
    assert (tmp_path / TRAINING_STEP).is_file()


def test_load_training_step(tmp_path):
    step = 5000
    save_training_step(step, tmp_path)
    loaded_step = load_training_step(tmp_path)
    assert loaded_step == step


def test_update_last_checkpoint(tmp_path):
    checkpoint = tmp_path / "0005"
    checkpoint.mkdir()
    update_last_checkpoint(checkpoint)
    last_checkpoint = tmp_path / LAST_CHECKPOINT_LINK
    assert last_checkpoint.is_symlink()
    assert last_checkpoint.resolve() == checkpoint


@patch("lerobot.common.utils.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
    policy = Mock()
    cfg = Mock()
    save_checkpoint(tmp_path, 10, cfg, policy, optimizer)
    policy.save_pretrained.assert_called_once()
    cfg.save_pretrained.assert_called_once()
    mock_save_training_state.assert_called_once()


def test_save_training_state(tmp_path, optimizer, scheduler):
    save_training_state(tmp_path, 10, optimizer, scheduler)
    assert (tmp_path / TRAINING_STATE_DIR).is_dir()
    assert (tmp_path / TRAINING_STATE_DIR / TRAINING_STEP).is_file()
    assert (tmp_path / TRAINING_STATE_DIR / RNG_STATE).is_file()
    assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_STATE).is_file()
    assert (tmp_path / TRAINING_STATE_DIR / OPTIMIZER_PARAM_GROUPS).is_file()
    assert (tmp_path / TRAINING_STATE_DIR / SCHEDULER_STATE).is_file()


def test_save_load_training_state(tmp_path, optimizer, scheduler):
    save_training_state(tmp_path, 10, optimizer, scheduler)
    loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
    assert loaded_step == 10
    assert loaded_optimizer is optimizer
    assert loaded_scheduler is scheduler
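Putting the pieces together, a condensed sketch (not in the diff) of the save/resume flow these helpers support, using only the call shapes exercised above; the output directory, step numbers, and the throwaway optimizer/scheduler are placeholders, and a real run's save_checkpoint() would also write the policy and config into the same step directory.

from pathlib import Path

import torch

from lerobot.common.optim.optimizers import AdamConfig
from lerobot.common.optim.schedulers import VQBeTSchedulerConfig
from lerobot.common.utils.train_utils import (
    get_step_checkpoint_dir,
    load_training_state,
    save_training_state,
    update_last_checkpoint,
)

# Placeholder optimizer/scheduler, built the same way as the fixtures above.
params = [torch.nn.Parameter(torch.randn(10, 10))]
optimizer = AdamConfig().build(params)
sum(p.sum() for p in params).backward()
optimizer.step()  # populate optimizer state before saving, as the fixture does
scheduler = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5).build(
    optimizer, num_training_steps=100
)

output_dir = Path("outputs/train/example_run")  # placeholder
step, total_steps = 5000, 100_000

checkpoint_dir = get_step_checkpoint_dir(output_dir, total_steps, step)  # .../checkpoints/005000
checkpoint_dir.mkdir(parents=True, exist_ok=True)
save_training_state(checkpoint_dir, step, optimizer, scheduler)  # step, RNG, optimizer, scheduler state
update_last_checkpoint(checkpoint_dir)  # points the "last" symlink at the newest step

# On resume: restore the step counter, optimizer and scheduler from that checkpoint.
step, optimizer, scheduler = load_training_state(checkpoint_dir, optimizer, scheduler)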
@@ -1,8 +1,3 @@
import random
from typing import Callable

import numpy as np
import pytest
import torch
from datasets import Dataset
@@ -10,50 +5,6 @@ from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)
from lerobot.common.utils.utils import (
    get_global_random_state,
    seeded_context,
    set_global_random_state,
    set_global_seed,
)

# Random generation functions for testing the seeding and random state get/set.
rand_fns = [
    random.random,
    np.random.random,
    lambda: torch.rand(1).item(),
]
if torch.cuda.is_available():
    rand_fns.append(lambda: torch.rand(1, device="cuda"))


@pytest.mark.parametrize("rand_fn", rand_fns)
def test_seeding(rand_fn: Callable[[], int]):
    set_global_seed(0)
    a = rand_fn()
    with seeded_context(1337):
        c = rand_fn()
    b = rand_fn()
    set_global_seed(0)
    a_ = rand_fn()
    b_ = rand_fn()
    # Check that `set_global_seed` lets us reproduce a and b.
    assert a_ == a
    # Additionally, check that the `seeded_context` didn't interrupt the global RNG.
    assert b_ == b
    set_global_seed(1337)
    c_ = rand_fn()
    # Check that `seeded_context` and `global_seed` give the same reproducibility.
    assert c_ == c


def test_get_set_random_state():
    """Check that getting the random state, then setting it results in the same random number generation."""
    random_state_dict = get_global_random_state()
    rand_numbers = [rand_fn() for rand_fn in rand_fns]
    set_global_random_state(random_state_dict)
    rand_numbers_ = [rand_fn() for rand_fn in rand_fns]
    assert rand_numbers_ == rand_numbers


def test_calculate_episode_data_index():