[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Authored by pre-commit-ci[bot] on 2025-03-24 13:16:38 +00:00
Committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions
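The hunks below are pure formatting fixes applied by the hooks: each over-long statement appears first in its original single-line form and then in the re-wrapped form that fits the line-length limit, with call arguments moved onto continuation lines and assert messages wrapped in parentheses. As a minimal illustration of that recurring pattern (the function and variable names here are made up for the sketch and do not come from the repository):

# Hypothetical helper standing in for any long call in the diff below.
def get_stats(repo_id: str, policy_name: str, policy_kwargs: dict) -> dict:
    return {"count": 1, "repo_id": repo_id, "policy": policy_name, **policy_kwargs}

# Single-line form, as written before the hook runs.
stats = get_stats("lerobot/pusht", "diffusion", {"optimizer_lr": 0.01})
assert stats["count"] > 0, f"Expected at least one sample for {stats['repo_id']}, got {stats['count']}"

# Re-wrapped form, matching the pattern in the hunks below: the arguments move
# onto a continuation line and the assert message gets its own parenthesized line.
stats = get_stats(
    "lerobot/pusht", "diffusion", {"optimizer_lr": 0.01}
)
assert stats["count"] > 0, (
    f"Expected at least one sample for {stats['repo_id']}, got {stats['count']}"
)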

View File

@@ -51,7 +51,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
batch = next(iter(dataloader))
loss, output_dict = policy.forward(batch)
if output_dict is not None:
output_dict = {k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)}
output_dict = {
k: v for k, v in output_dict.items() if isinstance(v, torch.Tensor)
}
output_dict["loss"] = loss
else:
output_dict = {"loss": loss}
@@ -69,7 +71,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
param_stats[f"{key}_std"] = (
param.std() if param.numel() > 1 else torch.tensor(float(0.0))
)
optimizer.zero_grad()
policy.reset()
@@ -96,11 +100,15 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
else:
actions_queue = train_cfg.policy.n_action_repeats
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
actions = {
str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)
}
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict):
def save_policy_to_safetensors(
output_dir: Path, ds_repo_id: str, policy_name: str, policy_kwargs: dict
):
if output_dir.exists():
print(f"Overwrite existing safetensors in '{output_dir}':")
print(f" - Validate with: `git add {output_dir}`")
@@ -108,7 +116,9 @@ def save_policy_to_safetensors(output_dir: Path, ds_repo_id: str, policy_name: s
shutil.rmtree(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, policy_name, policy_kwargs
)
save_file(output_dict, output_dir / "output_dict.safetensors")
save_file(grad_stats, output_dir / "grad_stats.safetensors")
save_file(param_stats, output_dir / "param_stats.safetensors")
@@ -141,5 +151,7 @@ if __name__ == "__main__":
raise RuntimeError("No policies were provided!")
for ds_repo_id, policy, policy_kwargs, file_name_extra in artifacts_cfg:
ds_name = ds_repo_id.split("/")[-1]
output_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
output_dir = (
Path("tests/artifacts/policies") / f"{ds_name}_{policy}_{file_name_extra}"
)
save_policy_to_safetensors(output_dir, ds_repo_id, policy, policy_kwargs)

View File

@@ -226,7 +226,13 @@ def test_save_images_from_cameras(tmp_path, request, camera_type, mock):
@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES)
@require_camera
def test_camera_rotation(request, camera_type, mock):
config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30}
config_kwargs = {
"camera_type": camera_type,
"mock": mock,
"width": 640,
"height": 480,
"fps": 30,
}
# No rotation.
camera = make_camera(**config_kwargs, rotation=None)

View File

@@ -9,7 +9,9 @@ from lerobot.common.envs.configs import EnvConfig
from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap
def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str:
def create_plugin_code(
*, base_class: str = "EnvConfig", plugin_name: str = "test_env"
) -> str:
"""Creates a dummy plugin module that implements its own EnvConfig subclass."""
return f"""
from dataclasses import dataclass

View File

@@ -31,7 +31,11 @@ from lerobot.common.datasets.compute_stats import (
def mock_load_image_as_numpy(path, dtype, channel_first):
return np.ones((3, 32, 32), dtype=dtype) if channel_first else np.ones((32, 32, 3), dtype=dtype)
return (
np.ones((3, 32, 32), dtype=dtype)
if channel_first
else np.ones((32, 32, 3), dtype=dtype)
)
@pytest.fixture
@@ -61,7 +65,10 @@ def test_sample_indices():
assert len(indices) == estimate_num_samples(10)
@patch("lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy)
@patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
)
def test_sample_images(mock_load):
image_paths = [f"image_{i}.jpg" for i in range(100)]
images = sample_images(image_paths)
@@ -74,9 +81,20 @@ def test_sample_images(mock_load):
def test_get_feature_stats_images():
data = np.random.rand(100, 3, 32, 32)
stats = get_feature_stats(data, axis=(0, 2, 3), keepdims=True)
assert "min" in stats and "max" in stats and "mean" in stats and "std" in stats and "count" in stats
assert (
"min" in stats
and "max" in stats
and "mean" in stats
and "std" in stats
and "count" in stats
)
np.testing.assert_equal(stats["count"], np.array([100]))
assert stats["min"].shape == stats["max"].shape == stats["mean"].shape == stats["std"].shape
assert (
stats["min"].shape
== stats["max"].shape
== stats["mean"].shape
== stats["std"].shape
)
def test_get_feature_stats_axis_0_keepdims(sample_array):
@@ -145,7 +163,8 @@ def test_compute_episode_stats():
}
with patch(
"lerobot.common.datasets.compute_stats.load_image_as_numpy", side_effect=mock_load_image_as_numpy
"lerobot.common.datasets.compute_stats.load_image_as_numpy",
side_effect=mock_load_image_as_numpy,
):
stats = compute_episode_stats(episode_data, features)
@@ -233,7 +252,13 @@ def test_aggregate_stats():
"std": [2.87, 5.87, 8.87],
"count": 10,
},
"observation.state": {"min": 1, "max": 10, "mean": 5.5, "std": 2.87, "count": 10},
"observation.state": {
"min": 1,
"max": 10,
"mean": 5.5,
"std": 2.87,
"count": 10,
},
"extra_key_0": {"min": 5, "max": 25, "mean": 15, "std": 6, "count": 6},
},
{
@@ -244,7 +269,13 @@ def test_aggregate_stats():
"std": [3.42, 2.42, 1.42],
"count": 15,
},
"observation.state": {"min": 2, "max": 15, "mean": 8.5, "std": 3.42, "count": 15},
"observation.state": {
"min": 2,
"max": 15,
"mean": 8.5,
"std": 3.42,
"count": 15,
},
"extra_key_1": {"min": 0, "max": 20, "mean": 10, "std": 5, "count": 5},
},
]
@@ -284,28 +315,47 @@ def test_aggregate_stats():
for ep_stats in all_stats:
for fkey, stats in ep_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
stats[k] = np.array(
stats[k], dtype=np.int64 if k == "count" else np.float32
)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
stats[k] = stats[k].reshape(
3, 1, 1
) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
# cast to numpy
for fkey, stats in expected_agg_stats.items():
for k in stats:
stats[k] = np.array(stats[k], dtype=np.int64 if k == "count" else np.float32)
stats[k] = np.array(
stats[k], dtype=np.int64 if k == "count" else np.float32
)
if fkey == "observation.image" and k != "count":
stats[k] = stats[k].reshape(3, 1, 1) # for normalization on image channels
stats[k] = stats[k].reshape(
3, 1, 1
) # for normalization on image channels
else:
stats[k] = stats[k].reshape(1)
results = aggregate_stats(all_stats)
for fkey in expected_agg_stats:
np.testing.assert_allclose(results[fkey]["min"], expected_agg_stats[fkey]["min"])
np.testing.assert_allclose(results[fkey]["max"], expected_agg_stats[fkey]["max"])
np.testing.assert_allclose(results[fkey]["mean"], expected_agg_stats[fkey]["mean"])
np.testing.assert_allclose(
results[fkey]["std"], expected_agg_stats[fkey]["std"], atol=1e-04, rtol=1e-04
results[fkey]["min"], expected_agg_stats[fkey]["min"]
)
np.testing.assert_allclose(
results[fkey]["max"], expected_agg_stats[fkey]["max"]
)
np.testing.assert_allclose(
results[fkey]["mean"], expected_agg_stats[fkey]["mean"]
)
np.testing.assert_allclose(
results[fkey]["std"],
expected_agg_stats[fkey]["std"],
atol=1e-04,
rtol=1e-04,
)
np.testing.assert_allclose(
results[fkey]["count"], expected_agg_stats[fkey]["count"]
)
np.testing.assert_allclose(results[fkey]["count"], expected_agg_stats[fkey]["count"])

View File

@@ -72,7 +72,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
dataset_create = LeRobotDataset.create(
repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create
)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
@@ -104,7 +106,8 @@ def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n"
ValueError,
match="Feature mismatch in `frame` dictionary:\nMissing features: {'task'}\n",
):
dataset.add_frame({"state": torch.randn(1)})
@@ -113,7 +116,8 @@ def test_add_frame_missing_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n"
ValueError,
match="Feature mismatch in `frame` dictionary:\nMissing features: {'state'}\n",
):
dataset.add_frame({"task": "Dummy task"})
@@ -122,18 +126,24 @@ def test_add_frame_extra_feature(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n"
ValueError,
match="Feature mismatch in `frame` dictionary:\nExtra features: {'extra'}\n",
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"})
dataset.add_frame(
{"state": torch.randn(1), "task": "Dummy task", "extra": "dummy_extra"}
)
def test_add_frame_wrong_type(tmp_path, empty_lerobot_dataset_factory):
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError, match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n"
ValueError,
match="The feature 'state' of dtype 'float16' is not of the expected dtype 'float32'.\n",
):
dataset.add_frame({"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"})
dataset.add_frame(
{"state": torch.randn(1, dtype=torch.float16), "task": "Dummy task"}
)
def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
@@ -141,7 +151,9 @@ def test_add_frame_wrong_shape(tmp_path, empty_lerobot_dataset_factory):
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"),
match=re.escape(
"The feature 'state' of shape '(1,)' does not have the expected shape '(2,)'.\n"
),
):
dataset.add_frame({"state": torch.randn(1), "task": "Dummy task"})
@@ -163,7 +175,9 @@ def test_add_frame_wrong_shape_torch_ndim_0(tmp_path, empty_lerobot_dataset_fact
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
with pytest.raises(
ValueError,
match=re.escape("The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"),
match=re.escape(
"The feature 'state' of shape '()' does not have the expected shape '(1,)'.\n"
),
):
dataset.add_frame({"state": torch.tensor(1.0), "task": "Dummy task"})
@@ -457,7 +471,9 @@ def test_flatten_unflatten_dict():
d = unflatten_dict(flatten_dict(d))
# test equality between nested dicts
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), (
f"{original_d} != {d}"
)
@pytest.mark.parametrize(
@@ -511,7 +527,13 @@ def test_backward_compatibility(repo_id):
load_and_compare(i + 1)
# test 2 frames at the middle of first episode
i = int((dataset.episode_data_index["to"][0].item() - dataset.episode_data_index["from"][0].item()) / 2)
i = int(
(
dataset.episode_data_index["to"][0].item()
- dataset.episode_data_index["from"][0].item()
)
/ 2
)
load_and_compare(i)
load_and_compare(i + 1)

View File

@@ -54,7 +54,9 @@ def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, np.n
@pytest.fixture(scope="module")
def synced_timestamps_factory(hf_dataset_factory):
def _create_synced_timestamps(fps: int = 30) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
def _create_synced_timestamps(
fps: int = 30,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
hf_dataset = hf_dataset_factory(fps=fps)
timestamps = torch.stack(hf_dataset["timestamp"]).numpy()
episode_indices = torch.stack(hf_dataset["episode_index"]).numpy()
@@ -69,8 +71,12 @@ def unsynced_timestamps_factory(synced_timestamps_factory):
def _create_unsynced_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 1.1 # Modify a single timestamp just outside tolerance
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
fps=fps
)
timestamps[30] += (
tolerance_s * 1.1
) # Modify a single timestamp just outside tolerance
return timestamps, episode_indices, episode_data_index
return _create_unsynced_timestamps
@@ -81,8 +87,12 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
def _create_slightly_off_timestamps(
fps: int = 30, tolerance_s: float = 1e-4
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(fps=fps)
timestamps[30] += tolerance_s * 0.9 # Modify a single timestamp just inside tolerance
timestamps, episode_indices, episode_data_index = synced_timestamps_factory(
fps=fps
)
timestamps[30] += (
tolerance_s * 0.9
) # Modify a single timestamp just inside tolerance
return timestamps, episode_indices, episode_data_index
return _create_slightly_off_timestamps
@@ -91,9 +101,13 @@ def slightly_off_timestamps_factory(synced_timestamps_factory):
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(
fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
fps: int = 30,
keys: list = DUMMY_MOTOR_FEATURES,
min_max_range: tuple[int, int] = (-10, 10),
) -> dict:
delta_timestamps = {key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys}
delta_timestamps = {
key: [i * (1 / fps) for i in range(*min_max_range)] for key in keys
}
return delta_timestamps
return _create_valid_delta_timestamps
@@ -130,7 +144,9 @@ def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
@pytest.fixture(scope="module")
def delta_indices_factory():
def _delta_indices(keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)) -> dict:
def _delta_indices(
keys: list = DUMMY_MOTOR_FEATURES, min_max_range: tuple[int, int] = (-10, 10)
) -> dict:
return {key: list(range(*min_max_range)) for key in keys}
return _delta_indices
@@ -182,7 +198,9 @@ def test_check_timestamps_sync_unsynced_no_exception(unsynced_timestamps_factory
def test_check_timestamps_sync_slightly_off(slightly_off_timestamps_factory):
fps = 30
tolerance_s = 1e-4
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(fps, tolerance_s)
timestamps, ep_idx, ep_data_index = slightly_off_timestamps_factory(
fps, tolerance_s
)
result = check_timestamps_sync(
timestamps=timestamps,
episode_indices=ep_idx,
@@ -223,7 +241,9 @@ def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(
fps, tolerance_s
)
result = check_delta_timestamps(
delta_timestamps=slightly_off_delta_timestamps,
fps=fps,

View File

@@ -33,7 +33,9 @@ from lerobot.scripts.visualize_image_transforms import (
save_all_transforms,
save_each_transform,
)
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR
from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import (
ARTIFACT_DIR,
)
from tests.utils import require_x86_64_kernel
@@ -80,7 +82,11 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"brightness": ImageTransformConfig(type="ColorJitter", kwargs={"brightness": min_max})},
tfs={
"brightness": ImageTransformConfig(
type="ColorJitter", kwargs={"brightness": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(brightness=min_max)
@@ -91,7 +97,12 @@ def test_get_image_transforms_brightness(img_tensor_factory, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"contrast": ImageTransformConfig(type="ColorJitter", kwargs={"contrast": min_max})}
enable=True,
tfs={
"contrast": ImageTransformConfig(
type="ColorJitter", kwargs={"contrast": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(contrast=min_max)
@@ -103,7 +114,11 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"saturation": ImageTransformConfig(type="ColorJitter", kwargs={"saturation": min_max})},
tfs={
"saturation": ImageTransformConfig(
type="ColorJitter", kwargs={"saturation": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(saturation=min_max)
@@ -114,7 +129,8 @@ def test_get_image_transforms_saturation(img_tensor_factory, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True, tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})}
enable=True,
tfs={"hue": ImageTransformConfig(type="ColorJitter", kwargs={"hue": min_max})},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = v2.ColorJitter(hue=min_max)
@@ -126,7 +142,11 @@ def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_cfg = ImageTransformsConfig(
enable=True,
tfs={"sharpness": ImageTransformConfig(type="SharpnessJitter", kwargs={"sharpness": min_max})},
tfs={
"sharpness": ImageTransformConfig(
type="SharpnessJitter", kwargs={"sharpness": min_max}
)
},
)
tf_actual = ImageTransforms(tf_cfg)
tf_expected = SharpnessJitter(sharpness=min_max)
@@ -342,7 +362,9 @@ def test_save_all_transforms(img_tensor_factory, tmp_path):
# Check if the combined transforms directory exists and contains the right files
combined_transforms_dir = tmp_path / "all"
assert combined_transforms_dir.exists(), "Combined transforms directory was not created."
assert combined_transforms_dir.exists(), (
"Combined transforms directory was not created."
)
assert any(combined_transforms_dir.iterdir()), (
"No transformed images found in combined transforms directory."
)
@@ -364,9 +386,9 @@ def test_save_each_transform(img_tensor_factory, tmp_path):
for transform in transforms:
transform_dir = tmp_path / transform
assert transform_dir.exists(), f"{transform} directory was not created."
assert any(
transform_dir.iterdir()
), f"No transformed images found in {transform} directory."
assert any(transform_dir.iterdir()), (
f"No transformed images found in {transform} directory."
)
# Check for specific files within each transform directory
expected_files = [f"{i}.png" for i in range(1, n_examples + 1)] + [

View File

@@ -176,7 +176,9 @@ def test_delta_timestamps_within_tolerance():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item[f"index{OnlineBuffer.IS_PAD_POSTFIX}"]
torch.testing.assert_close(data, torch.tensor([0, 2, 3]), msg="Data does not match expected values")
torch.testing.assert_close(
data, torch.tensor([0, 2, 3]), msg="Data does not match expected values"
)
assert not is_pad.any(), "Unexpected padding detected"
@@ -212,7 +214,9 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
buffer.tolerance_s = 0.04
item = buffer[2]
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), (
"Data does not match expected values"
)
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), (
"Padding does not match expected values"
)
@@ -275,7 +279,8 @@ def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_p
online_sampling_ratio=online_sampling_ratio,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
)
@@ -297,7 +302,8 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(
online_drop_n_last_frames=1,
)
torch.testing.assert_close(
weights, torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0])
weights,
torch.tensor([0.05, 0.05, 0.05, 0.05, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0, 0.2, 0.0]),
)
@@ -318,4 +324,6 @@ def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp
online_sampling_ratio=0.5,
online_drop_n_last_frames=1,
)
torch.testing.assert_close(weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0]))
torch.testing.assert_close(
weights, torch.tensor([0.5, 0, 0.125, 0, 0.125, 0, 0.125, 0, 0.125, 0])
)

View File

@@ -18,8 +18,13 @@ import torch
from datasets import Dataset
from huggingface_hub import DatasetCard
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import create_lerobot_dataset_card, hf_transform_to_torch
from lerobot.common.datasets.push_dataset_to_hub.utils import (
calculate_episode_data_index,
)
from lerobot.common.datasets.utils import (
create_lerobot_dataset_card,
hf_transform_to_torch,
)
def test_default_parameters():

View File

@@ -210,7 +210,10 @@ def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> int:
tasks = {}
for task_index in range(total_tasks):
task_dict = {"task_index": task_index, "task": f"Perform action {task_index}."}
task_dict = {
"task_index": task_index,
"task": f"Perform action {task_index}.",
}
tasks[task_index] = task_dict
return tasks
@@ -297,8 +300,12 @@ def hf_dataset_factory(
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes.values():
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
timestamp_col = np.concatenate(
(timestamp_col, np.arange(ep_dict["length"]) / fps)
)
frame_index_col = np.concatenate(
(frame_index_col, np.arange(ep_dict["length"], dtype=int))
)
episode_index_col = np.concatenate(
(
episode_index_col,
@@ -385,7 +392,9 @@ def lerobot_dataset_metadata_factory(
episodes=episodes,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
@@ -433,7 +442,9 @@ def lerobot_dataset_factory(
if not stats:
stats = stats_factory(features=info["features"])
if not episodes_stats:
episodes_stats = episodes_stats_factory(features=info["features"], total_episodes=total_episodes)
episodes_stats = episodes_stats_factory(
features=info["features"], total_episodes=total_episodes
)
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
@@ -466,8 +477,12 @@ def lerobot_dataset_factory(
episodes=episode_dicts,
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata"
) as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.get_safe_version"
) as mock_get_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,

View File

@@ -59,7 +59,9 @@ def stats_path(stats_factory):
@pytest.fixture(scope="session")
def episodes_stats_path(episodes_stats_factory):
def _create_episodes_stats_jsonl_file(dir: Path, episodes_stats: list[dict] | None = None) -> Path:
def _create_episodes_stats_jsonl_file(
dir: Path, episodes_stats: list[dict] | None = None
) -> Path:
if not episodes_stats:
episodes_stats = episodes_stats_factory()
fpath = dir / EPISODES_STATS_PATH

View File

@@ -99,7 +99,13 @@ def mock_snapshot_download_factory(
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, EPISODES_STATS_PATH, TASKS_PATH, EPISODES_PATH]
meta_files = [
INFO_PATH,
STATS_PATH,
EPISODES_STATS_PATH,
TASKS_PATH,
EPISODES_PATH,
]
all_files.extend(meta_files)
data_files = []

View File

@@ -35,5 +35,7 @@ def optimizer(model_params):
@pytest.fixture
def scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
config = VQBeTSchedulerConfig(
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
)
return config.build(optimizer, num_training_steps=100)

View File

@@ -43,7 +43,9 @@ def test_diffuser_scheduler(optimizer):
def test_vqbet_scheduler(optimizer):
config = VQBeTSchedulerConfig(num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5)
config = VQBeTSchedulerConfig(
num_warmup_steps=10, num_vqvae_training_steps=20, num_cycles=0.5
)
scheduler = config.build(optimizer, num_training_steps=100)
assert isinstance(scheduler, LambdaLR)

View File

@@ -59,16 +59,33 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
"observation.state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
}
info = info_factory(
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_episodes=1,
total_frames=1,
camera_features=camera_features,
motor_features=motor_features,
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta
@@ -81,7 +98,8 @@ def test_get_policy_and_config_classes(policy_name: str):
policy_cfg = make_policy_config(policy_name)
assert policy_cls.name == policy_name
assert issubclass(
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
policy_cfg.__class__,
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
)
@@ -92,7 +110,13 @@ def test_get_policy_and_config_classes(policy_name: str):
("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}),
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
{"task": "AlohaInsertion-v0"},
"act",
{},
),
(
"lerobot/aloha_sim_insertion_scripted",
"aloha",
@@ -172,11 +196,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(
batch_
), "Batch keys are not the same after a forward pass."
assert set(batch) == set(batch_), (
"Batch keys are not the same after a forward pass."
)
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
torch.equal(batch[k], batch_[k])
if isinstance(batch[k], torch.Tensor)
else batch[k] == batch_[k]
for k in batch
), "Batch values are not the same after a forward pass."
@@ -215,8 +241,12 @@ def test_act_backbone_lr():
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
dataset=DatasetConfig(
repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]
),
policy=make_policy_config(
"act", optimizer_lr=0.01, optimizer_lr_backbone=0.001
),
)
cfg.validate() # Needed for auto-setting some parameters
@@ -239,7 +269,9 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
@@ -251,7 +283,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
@@ -260,7 +294,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
torch.testing.assert_close(
list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0
)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
@@ -400,7 +436,9 @@ def test_normalize(insert_temporal_dim):
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
def test_backward_compatibility(
ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str
):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@@ -414,13 +452,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
"""
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
artifact_dir = (
Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
saved_actions = load_file(artifact_dir / "actions.safetensors")
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, policy_name, policy_kwargs
)
for key in saved_output_dict:
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
@@ -429,8 +471,12 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
for key in saved_param_stats:
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
for key in saved_actions:
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
rtol, atol = (
(2e-3, 5e-6) if policy_name == "diffusion" else (None, None)
) # HACK
torch.testing.assert_close(
actions[key], saved_actions[key], rtol=rtol, atol=atol
)
def test_act_temporal_ensembler():

View File

@@ -179,7 +179,9 @@ def test_record_and_replay_and_policy(tmp_path, request, robot_type, mock):
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay_cfg = ReplayControlConfig(episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
replay_cfg = ReplayControlConfig(
episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False
)
replay(robot, replay_cfg)
policy_cfg = ACTConfig()
@@ -334,12 +336,12 @@ def test_record_with_event_rerecord_episode(tmp_path, request, robot_type, mock)
)
dataset = record(robot, rec_cfg)
assert not mock_events[
"rerecord_episode"
], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events[
"exit_early"
], "`exit_early` wasn't properly reset to False"
assert not mock_events["rerecord_episode"], (
"`rerecord_episode` wasn't properly reset to False"
)
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@@ -389,7 +391,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
dataset = record(robot, rec_cfg)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@@ -398,7 +402,9 @@ def test_record_with_event_exit_early(tmp_path, request, robot_type, mock):
[("koch", True, 0), ("koch", True, 1)],
)
@require_robot
def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, num_image_writer_processes):
def test_record_with_event_stop_recording(
tmp_path, request, robot_type, mock, num_image_writer_processes
):
robot_kwargs = {"robot_type": robot_type, "mock": mock}
if mock:
@@ -444,5 +450,7 @@ def test_record_with_event_stop_recording(tmp_path, request, robot_type, mock, n
dataset = record(robot, rec_cfg)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert not mock_events["exit_early"], (
"`exit_early` wasn't properly reset to False"
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"

View File

@@ -40,7 +40,10 @@ import pytest
import torch
from lerobot.common.robot_devices.robots.utils import make_robot
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.utils import (
RobotDeviceAlreadyConnectedError,
RobotDeviceNotConnectedError,
)
from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot
@@ -131,7 +134,9 @@ def test_robot(tmp_path, request, robot_type, mock):
if "image" in name:
# TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames
continue
torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1)
torch.testing.assert_close(
captured_observation[name], observation[name], rtol=1e-4, atol=1
)
assert captured_observation[name].shape == observation[name].shape
# Test send_action can run

View File

@@ -227,9 +227,9 @@ def test_resume_function(
config_dir = os.path.abspath(
os.path.join(test_file_dir, "..", "lerobot", "configs", "policy")
)
assert os.path.exists(
config_dir
), f"Config directory does not exist at {config_dir}"
assert os.path.exists(config_dir), (
f"Config directory does not exist at {config_dir}"
)
with initialize_config_dir(
config_dir=config_dir, job_name="test_app", version_base="1.2"

View File

@@ -26,10 +26,16 @@ from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.robot_devices.cameras.utils import Camera
from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device
from lerobot.common.robot_devices.motors.utils import MotorsBus
from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device
from lerobot.common.robot_devices.motors.utils import (
make_motors_bus as make_motors_bus_device,
)
from lerobot.common.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu"
DEVICE = (
os.environ.get("LEROBOT_TEST_DEVICE", "cuda")
if torch.cuda.is_available()
else "cpu"
)
TEST_ROBOT_TYPES = []
for robot_type in available_robots:
@@ -45,7 +51,9 @@ for motor_type in available_motors:
# Camera indices used for connecting physical cameras
OPENCV_CAMERA_INDEX = int(os.environ.get("LEROBOT_TEST_OPENCV_CAMERA_INDEX", 0))
INTELREALSENSE_SERIAL_NUMBER = int(os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614))
INTELREALSENSE_SERIAL_NUMBER = int(
os.environ.get("LEROBOT_TEST_INTELREALSENSE_SERIAL_NUMBER", 128422271614)
)
DYNAMIXEL_PORT = os.environ.get(
"LEROBOT_TEST_DYNAMIXEL_PORT", "/dev/tty.usbmodem575E0032081"

View File

@@ -18,7 +18,10 @@ from lerobot.common.utils.logging_utils import AverageMeter, MetricsTracker
@pytest.fixture
def mock_metrics():
return {"loss": AverageMeter("loss", ":.3f"), "accuracy": AverageMeter("accuracy", ":.2f")}
return {
"loss": AverageMeter("loss", ":.3f"),
"accuracy": AverageMeter("accuracy", ":.2f"),
}
def test_average_meter_initialization():
@@ -58,7 +61,11 @@ def test_average_meter_str():
def test_metrics_tracker_initialization(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=10
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=10,
)
assert tracker.steps == 10
assert tracker.samples == 10 * 32
@@ -70,7 +77,11 @@ def test_metrics_tracker_initialization(mock_metrics):
def test_metrics_tracker_step(mock_metrics):
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics, initial_step=5
batch_size=32,
num_frames=1000,
num_episodes=50,
metrics=mock_metrics,
initial_step=5,
)
tracker.step()
assert tracker.steps == 6
@@ -80,7 +91,9 @@ def test_metrics_tracker_step(mock_metrics):
def test_metrics_tracker_getattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
assert tracker.loss == mock_metrics["loss"]
assert tracker.accuracy == mock_metrics["accuracy"]
with pytest.raises(AttributeError):
@@ -88,13 +101,17 @@ def test_metrics_tracker_getattr(mock_metrics):
def test_metrics_tracker_setattr(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss = 2.0
assert tracker.loss.val == 2.0
def test_metrics_tracker_str(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(3.456, 1)
tracker.accuracy.update(0.876, 1)
output = str(tracker)
@@ -103,7 +120,9 @@ def test_metrics_tracker_str(mock_metrics):
def test_metrics_tracker_to_dict(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(5, 2)
metrics_dict = tracker.to_dict()
assert isinstance(metrics_dict, dict)
@@ -112,7 +131,9 @@ def test_metrics_tracker_to_dict(mock_metrics):
def test_metrics_tracker_reset_averages(mock_metrics):
tracker = MetricsTracker(batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics)
tracker = MetricsTracker(
batch_size=32, num_frames=1000, num_episodes=50, metrics=mock_metrics
)
tracker.loss.update(10, 3)
tracker.accuracy.update(0.95, 5)
tracker.reset_averages()

View File

@@ -118,5 +118,9 @@ def test_seeded_context(fixed_seed):
seeded_val2 = (random.random(), np.random.rand(), torch.rand(1).item())
assert seeded_val1 == seeded_val2
assert all(a != b for a, b in zip(val1, seeded_val1, strict=True)) # changed inside the context
assert all(a != b for a, b in zip(val2, seeded_val2, strict=True)) # changed again after exiting
assert all(
a != b for a, b in zip(val1, seeded_val1, strict=True)
) # changed inside the context
assert all(
a != b for a, b in zip(val2, seeded_val2, strict=True)
) # changed again after exiting

View File

@@ -91,7 +91,9 @@ def test_save_training_state(tmp_path, optimizer, scheduler):
def test_save_load_training_state(tmp_path, optimizer, scheduler):
save_training_state(tmp_path, 10, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(tmp_path, optimizer, scheduler)
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
tmp_path, optimizer, scheduler
)
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler