From 92d1aecb4034f5e7e1610fb8f299a22f759de5f3 Mon Sep 17 00:00:00 2001 From: Remi Cadene Date: Thu, 30 May 2024 17:54:59 +0000 Subject: [PATCH] Rename dora_aloha_real, WIP test_policies --- lerobot/__init__.py | 6 +++--- lerobot/configs/env/aloha_real.yaml | 13 ------------- tests/scripts/save_policy_to_safetensors.py | 6 +++--- 3 files changed, 6 insertions(+), 19 deletions(-) delete mode 100644 lerobot/configs/env/aloha_real.yaml diff --git a/lerobot/__init__.py b/lerobot/__init__.py index c3f67a32..4d8d6ec5 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -55,7 +55,7 @@ available_tasks_per_env = { ], "pusht": ["PushT-v0"], "xarm": ["XarmLift-v0"], - "dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"], + "dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"], } available_envs = list(available_tasks_per_env.keys()) @@ -81,7 +81,7 @@ available_datasets_per_env = { "lerobot/xarm_push_medium_image", "lerobot/xarm_push_medium_replay_image", ], - "dora_aloha_real": [ + "dora": [ "lerobot/aloha_static_battery", "lerobot/aloha_static_candy", "lerobot/aloha_static_coffee", @@ -139,7 +139,7 @@ available_policies = [ # keys and values refer to yaml files available_policies_per_env = { "aloha": ["act"], - "aloha_real": ["act"], + "dora": ["act"], "pusht": ["diffusion"], "xarm": ["tdmpc"], "dora_aloha_real": ["act_real"], diff --git a/lerobot/configs/env/aloha_real.yaml b/lerobot/configs/env/aloha_real.yaml deleted file mode 100644 index 088781d4..00000000 --- a/lerobot/configs/env/aloha_real.yaml +++ /dev/null @@ -1,13 +0,0 @@ -# @package _global_ - -fps: 30 - -env: - name: dora - task: DoraAloha-v0 - state_dim: 14 - action_dim: 14 - fps: ${fps} - episode_length: 400 - gym: - fps: ${fps} diff --git a/tests/scripts/save_policy_to_safetensors.py b/tests/scripts/save_policy_to_safetensors.py index 961b7cef..f867a5e8 100644 --- a/tests/scripts/save_policy_to_safetensors.py +++ b/tests/scripts/save_policy_to_safetensors.py @@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides): batch = next(iter(dataloader)) obs = {} for k in batch: - if k.startswith("observation"): + if "observation" in k: obs[k] = batch[k] if "n_action_steps" in cfg.policy: @@ -115,8 +115,8 @@ if __name__ == "__main__": ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"], ), ("aloha", "act", ["policy.n_action_steps=10"]), - ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]), - ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]), + ("dora_aloha_real", "act_real", []), + ("dora_aloha_real", "act_real_no_state", []), ] for env, policy, extra_overrides in env_policies: save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)