forked from tangger/lerobot
Rename dora_aloha_real, WIP test_policies
This commit is contained in:
@@ -55,7 +55,7 @@ available_tasks_per_env = {
|
|||||||
],
|
],
|
||||||
"pusht": ["PushT-v0"],
|
"pusht": ["PushT-v0"],
|
||||||
"xarm": ["XarmLift-v0"],
|
"xarm": ["XarmLift-v0"],
|
||||||
"dora_aloha_real": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
"dora": ["DoraAloha-v0", "DoraKoch-v0", "DoraReachy2-v0"],
|
||||||
}
|
}
|
||||||
available_envs = list(available_tasks_per_env.keys())
|
available_envs = list(available_tasks_per_env.keys())
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ available_datasets_per_env = {
|
|||||||
"lerobot/xarm_push_medium_image",
|
"lerobot/xarm_push_medium_image",
|
||||||
"lerobot/xarm_push_medium_replay_image",
|
"lerobot/xarm_push_medium_replay_image",
|
||||||
],
|
],
|
||||||
"dora_aloha_real": [
|
"dora": [
|
||||||
"lerobot/aloha_static_battery",
|
"lerobot/aloha_static_battery",
|
||||||
"lerobot/aloha_static_candy",
|
"lerobot/aloha_static_candy",
|
||||||
"lerobot/aloha_static_coffee",
|
"lerobot/aloha_static_coffee",
|
||||||
@@ -139,7 +139,7 @@ available_policies = [
|
|||||||
# keys and values refer to yaml files
|
# keys and values refer to yaml files
|
||||||
available_policies_per_env = {
|
available_policies_per_env = {
|
||||||
"aloha": ["act"],
|
"aloha": ["act"],
|
||||||
"aloha_real": ["act"],
|
"dora": ["act"],
|
||||||
"pusht": ["diffusion"],
|
"pusht": ["diffusion"],
|
||||||
"xarm": ["tdmpc"],
|
"xarm": ["tdmpc"],
|
||||||
"dora_aloha_real": ["act_real"],
|
"dora_aloha_real": ["act_real"],
|
||||||
|
|||||||
13
lerobot/configs/env/aloha_real.yaml
vendored
13
lerobot/configs/env/aloha_real.yaml
vendored
@@ -1,13 +0,0 @@
|
|||||||
# @package _global_
|
|
||||||
|
|
||||||
fps: 30
|
|
||||||
|
|
||||||
env:
|
|
||||||
name: dora
|
|
||||||
task: DoraAloha-v0
|
|
||||||
state_dim: 14
|
|
||||||
action_dim: 14
|
|
||||||
fps: ${fps}
|
|
||||||
episode_length: 400
|
|
||||||
gym:
|
|
||||||
fps: ${fps}
|
|
||||||
@@ -77,7 +77,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
|||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
obs = {}
|
obs = {}
|
||||||
for k in batch:
|
for k in batch:
|
||||||
if k.startswith("observation"):
|
if "observation" in k:
|
||||||
obs[k] = batch[k]
|
obs[k] = batch[k]
|
||||||
|
|
||||||
if "n_action_steps" in cfg.policy:
|
if "n_action_steps" in cfg.policy:
|
||||||
@@ -115,8 +115,8 @@ if __name__ == "__main__":
|
|||||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||||
),
|
),
|
||||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
("dora_aloha_real", "act_real", []),
|
||||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
("dora_aloha_real", "act_real_no_state", []),
|
||||||
]
|
]
|
||||||
for env, policy, extra_overrides in env_policies:
|
for env, policy, extra_overrides in env_policies:
|
||||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
||||||
|
|||||||
Reference in New Issue
Block a user