Rename dora_aloha_real, WIP test_policies

This commit is contained in:
Remi Cadene
2024-05-30 17:54:59 +00:00
parent b7b5c3b4ff
commit 671ad93b6c
4 changed files with 35 additions and 10 deletions

View File

@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.utils.utils import init_hydra_config
from tests.scripts.save_policy_to_safetensor import get_policy_stats
from tests.scripts.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
@@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
),
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
("dora_aloha_real", "act_real", []),
("dora_aloha_real", "act_real_no_state", []),
],
)
@require_env
@@ -291,6 +293,8 @@ def test_normalize(insert_temporal_dim):
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
],
)
# As artifacts have been generated on an x86_64 kernel, this test won't