[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by Michel Aractingi
parent cdcf346061
commit 1c8daf11fd
95 changed files with 1592 additions and 491 deletions

View File

@@ -59,16 +59,33 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
"observation.state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
"names": [
"shoulder_pan",
"shoulder_lift",
"elbow_flex",
"wrist_flex",
"wrist_roll",
"gripper",
],
},
}
info = info_factory(
total_episodes=1, total_frames=1, camera_features=camera_features, motor_features=motor_features
total_episodes=1,
total_frames=1,
camera_features=camera_features,
motor_features=motor_features,
)
ds_meta = lerobot_dataset_metadata_factory(root=tmp_path / "init", info=info)
return ds_meta
@@ -81,7 +98,8 @@ def test_get_policy_and_config_classes(policy_name: str):
policy_cfg = make_policy_config(policy_name)
assert policy_cls.name == policy_name
assert issubclass(
policy_cfg.__class__, inspect.signature(policy_cls.__init__).parameters["config"].annotation
policy_cfg.__class__,
inspect.signature(policy_cls.__init__).parameters["config"].annotation,
)
@@ -92,7 +110,13 @@ def test_get_policy_and_config_classes(policy_name: str):
("lerobot/pusht", "pusht", {}, "diffusion", {}),
("lerobot/pusht", "pusht", {}, "vqbet", {}),
("lerobot/pusht", "pusht", {}, "act", {}),
("lerobot/aloha_sim_insertion_human", "aloha", {"task": "AlohaInsertion-v0"}, "act", {}),
(
"lerobot/aloha_sim_insertion_human",
"aloha",
{"task": "AlohaInsertion-v0"},
"act",
{},
),
(
"lerobot/aloha_sim_insertion_scripted",
"aloha",
@@ -172,11 +196,13 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
# Test updating the policy (and test that it does not mutate the batch)
batch_ = deepcopy(batch)
policy.forward(batch)
assert set(batch) == set(
batch_
), "Batch keys are not the same after a forward pass."
assert set(batch) == set(batch_), (
"Batch keys are not the same after a forward pass."
)
assert all(
torch.equal(batch[k], batch_[k]) if isinstance(batch[k], torch.Tensor) else batch[k] == batch_[k]
torch.equal(batch[k], batch_[k])
if isinstance(batch[k], torch.Tensor)
else batch[k] == batch_[k]
for k in batch
), "Batch values are not the same after a forward pass."
@@ -215,8 +241,12 @@ def test_act_backbone_lr():
cfg = TrainPipelineConfig(
# TODO(rcadene, aliberts): remove dataset download
dataset=DatasetConfig(repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]),
policy=make_policy_config("act", optimizer_lr=0.01, optimizer_lr_backbone=0.001),
dataset=DatasetConfig(
repo_id="lerobot/aloha_sim_insertion_scripted", episodes=[0]
),
policy=make_policy_config(
"act", optimizer_lr=0.01, optimizer_lr_backbone=0.001
),
)
cfg.validate() # Needed for auto-setting some parameters
@@ -239,7 +269,9 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
@@ -251,7 +283,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.output_features = {
key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION
}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
@@ -260,7 +294,9 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
policy.save_pretrained(save_dir)
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
torch.testing.assert_close(
list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0
)
@pytest.mark.parametrize("insert_temporal_dim", [False, True])
@@ -400,7 +436,9 @@ def test_normalize(insert_temporal_dim):
# pass if it's run on another platform due to floating point errors
@require_x86_64_kernel
@require_cpu
def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str):
def test_backward_compatibility(
ds_repo_id: str, policy_name: str, policy_kwargs: dict, file_name_extra: str
):
"""
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
@@ -414,13 +452,17 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
6. Remember to stage and commit the resulting changes to `tests/artifacts`.
"""
ds_name = ds_repo_id.split("/")[-1]
artifact_dir = Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
artifact_dir = (
Path("tests/artifacts/policies") / f"{ds_name}_{policy_name}_{file_name_extra}"
)
saved_output_dict = load_file(artifact_dir / "output_dict.safetensors")
saved_grad_stats = load_file(artifact_dir / "grad_stats.safetensors")
saved_param_stats = load_file(artifact_dir / "param_stats.safetensors")
saved_actions = load_file(artifact_dir / "actions.safetensors")
output_dict, grad_stats, param_stats, actions = get_policy_stats(ds_repo_id, policy_name, policy_kwargs)
output_dict, grad_stats, param_stats, actions = get_policy_stats(
ds_repo_id, policy_name, policy_kwargs
)
for key in saved_output_dict:
torch.testing.assert_close(output_dict[key], saved_output_dict[key])
@@ -429,8 +471,12 @@ def test_backward_compatibility(ds_repo_id: str, policy_name: str, policy_kwargs
for key in saved_param_stats:
torch.testing.assert_close(param_stats[key], saved_param_stats[key])
for key in saved_actions:
rtol, atol = (2e-3, 5e-6) if policy_name == "diffusion" else (None, None) # HACK
torch.testing.assert_close(actions[key], saved_actions[key], rtol=rtol, atol=atol)
rtol, atol = (
(2e-3, 5e-6) if policy_name == "diffusion" else (None, None)
) # HACK
torch.testing.assert_close(
actions[key], saved_actions[key], rtol=rtol, atol=atol
)
def test_act_temporal_ensembler():