Rename dora_aloha_real, WIP test_policies
This commit is contained in:
@@ -75,15 +75,16 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
# HACK: We reload a batch with no delta_timestamps as `select_action` won't expect a timestamps dimension
|
||||
dataset.delta_timestamps = None
|
||||
batch = next(iter(dataloader))
|
||||
obs = {
|
||||
k: batch[k]
|
||||
for k in batch
|
||||
if k in ["observation.image", "observation.images.top", "observation.state"]
|
||||
}
|
||||
obs = {}
|
||||
for k in batch:
|
||||
if "observation" in k:
|
||||
obs[k] = batch[k]
|
||||
|
||||
if "n_action_steps" in cfg.policy:
|
||||
actions_queue = cfg.policy.n_action_steps
|
||||
else:
|
||||
actions_queue = cfg.policy.n_action_repeats
|
||||
|
||||
actions_queue = (
|
||||
cfg.policy.n_action_steps if "n_action_steps" in cfg.policy else cfg.policy.n_action_repeats
|
||||
)
|
||||
actions = {str(i): policy.select_action(obs).contiguous() for i in range(actions_queue)}
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
@@ -114,6 +115,8 @@ if __name__ == "__main__":
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", []),
|
||||
("dora_aloha_real", "act_real_no_state", []),
|
||||
]
|
||||
for env, policy, extra_overrides in env_policies:
|
||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
||||
@@ -30,7 +30,7 @@ from lerobot.common.policies.factory import get_policy_and_config_classes, make_
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
from tests.scripts.save_policy_to_safetensor import get_policy_stats
|
||||
from tests.scripts.save_policy_to_safetensors import get_policy_stats
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_cpu, require_env, require_x86_64_kernel
|
||||
|
||||
|
||||
@@ -72,6 +72,8 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
),
|
||||
# Note: these parameters also need custom logic in the test function for overriding the Hydra config.
|
||||
("pusht", "act", ["env.task=PushT-v0", "dataset_repo_id=lerobot/pusht"]),
|
||||
("dora_aloha_real", "act_real", []),
|
||||
("dora_aloha_real", "act_real_no_state", []),
|
||||
],
|
||||
)
|
||||
@require_env
|
||||
@@ -291,6 +293,8 @@ def test_normalize(insert_temporal_dim):
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||
],
|
||||
)
|
||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||
|
||||
Reference in New Issue
Block a user