Merge branch 'main' into thomwolf_2024_06_18_fix_normalization
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3f4e0e525aeb22ea94b79e26b39a87e6f2da9fbee33e493906aaf2aad9a7c1ef
|
||||
size 515400
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6dc658a1c1616c7d1c211eb8f87cec3d44f7b67d6b3cea7a6ce12b32d74674da
|
||||
size 31688
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:03971f92b7907b6b7e6ac207f508666104cd84c26c5276f510c431db604e188b
|
||||
size 68
|
||||
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:01d993c67a9267032fe9fbeff20b4359c209464976ea503040a0a76ae213450a
|
||||
size 33408
|
||||
@@ -89,8 +89,8 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
return output_dict, grad_stats, param_stats, actions
|
||||
|
||||
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides):
|
||||
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}"
|
||||
def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_overrides, file_name_extra):
|
||||
env_policy_dir = Path(output_dir) / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
|
||||
if env_policy_dir.exists():
|
||||
print(f"Overwrite existing safetensors in '{env_policy_dir}':")
|
||||
@@ -108,15 +108,17 @@ def save_policy_to_safetensors(output_dir, env_name, policy_name, extra_override
|
||||
|
||||
if __name__ == "__main__":
|
||||
env_policies = [
|
||||
("xarm", "tdmpc", []),
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||
# ("xarm", "tdmpc", []),
|
||||
# (
|
||||
# "pusht",
|
||||
# "diffusion",
|
||||
# ["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
# ),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
# ("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||
# ("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||
]
|
||||
for env, policy, extra_overrides in env_policies:
|
||||
save_policy_to_safetensors("tests/data/save_policy_to_safetensors", env, policy, extra_overrides)
|
||||
for env, policy, extra_overrides, file_name_extra in env_policies:
|
||||
save_policy_to_safetensors(
|
||||
"tests/data/save_policy_to_safetensors", env, policy, extra_overrides, file_name_extra
|
||||
)
|
||||
|
||||
@@ -624,24 +624,26 @@ def test_normalize(insert_temporal_dim):
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, policy_name, extra_overrides",
|
||||
"env_name, policy_name, extra_overrides, file_name_extra",
|
||||
[
|
||||
("xarm", "tdmpc", []),
|
||||
("xarm", "tdmpc", [], ""),
|
||||
(
|
||||
"pusht",
|
||||
"diffusion",
|
||||
["policy.n_action_steps=8", "policy.num_inference_steps=10", "policy.down_dims=[128, 256, 512]"],
|
||||
"",
|
||||
),
|
||||
("aloha", "act", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"]),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"]),
|
||||
("aloha", "act", ["policy.n_action_steps=10"], ""),
|
||||
("aloha", "act", ["policy.n_action_steps=1000", "policy.chunk_size=1000"], "_1000_steps"),
|
||||
("dora_aloha_real", "act_real", ["policy.n_action_steps=10"], ""),
|
||||
("dora_aloha_real", "act_real_no_state", ["policy.n_action_steps=10"], ""),
|
||||
],
|
||||
)
|
||||
# As artifacts have been generated on an x86_64 kernel, this test won't
|
||||
# pass if it's run on another platform due to floating point errors
|
||||
@require_x86_64_kernel
|
||||
@require_cpu
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||
def test_backward_compatibility(env_name, policy_name, extra_overrides, file_name_extra):
|
||||
"""
|
||||
NOTE: If this test does not pass, and you have intentionally changed something in the policy:
|
||||
1. Inspect the differences in policy outputs and make sure you can account for them. Your PR should
|
||||
@@ -653,7 +655,9 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides):
|
||||
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.
|
||||
6. Remember to stage and commit the resulting changes to `tests/data`.
|
||||
"""
|
||||
env_policy_dir = Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}"
|
||||
env_policy_dir = (
|
||||
Path("tests/data/save_policy_to_safetensors") / f"{env_name}_{policy_name}{file_name_extra}"
|
||||
)
|
||||
saved_output_dict = load_file(env_policy_dir / "output_dict.safetensors")
|
||||
saved_grad_stats = load_file(env_policy_dir / "grad_stats.safetensors")
|
||||
saved_param_stats = load_file(env_policy_dir / "param_stats.safetensors")
|
||||
|
||||
Reference in New Issue
Block a user