Bug fix: missing attention mask in VAE encoder in ACT policy (#279)

Co-authored-by: Alexander Soare <alexander.soare159@gmail.com>
This commit is contained in:
Thomas Wolf
2024-06-19 13:07:21 +02:00
committed by GitHub
parent 56199fb76f
commit 48951662f2
7 changed files with 60 additions and 25 deletions

View File

@@ -89,8 +89,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
return output_dict, grad_stats, param_stats, actions
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
if env_policy_dir.exists():
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
@@ -108,15 +108,17 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
if __name__ == "__main__":
env_policies = [
("xarm", "tdmpc", []),
(
"pusht",
"diffusion",
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
),
("aloha", "act", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
# ("xarm", "tdmpc", []),
# (
# "pusht",
# "diffusion",
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
# ),
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
]
for env, policy, extra_overrides in env_policies:
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
for env, policy, extra_overrides, file_name_extra in env_policies:
save_policy_to_safetensors(
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
)