Refactor TD-MPC (#103)

Co-authored-by: Cadene <re.cadene@gmail.com>
Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
Alexander Soare
2024-05-01 16:40:04 +01:00
committed by GitHub
parent a4891095e4
commit d1855a202a
17 changed files with 1105 additions and 1205 deletions

View File

@@ -6,7 +6,7 @@ import pytest
import lerobot
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from tests.utils import require_env

View File

@@ -19,10 +19,6 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
@pytest.mark.parametrize("policy_name", available_policies)
def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned."""
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
assert policy_cls.name == policy_name
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
@@ -32,8 +28,7 @@ def test_get_policy_and_config_classes(policy_name: str):
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
# ("xarm", "tdmpc", ["policy.mpc=true"]),
# ("pusht", "tdmpc", ["policy.mpc=false"]),
("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
("pusht", "diffusion", []),
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
(
@@ -103,7 +98,7 @@ def test_policy(env_name, policy_name, extra_overrides):
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy
policy.forward(batch, step=0)
policy.forward(batch)
# reset the policy and environment
policy.reset()
@@ -117,7 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# get the next action for the environment
with torch.inference_mode():
action = policy.select_action(observation, step=0)
action = policy.select_action(observation)
# convert action to cpu numpy array
action = postprocess_action(action)
@@ -129,20 +124,12 @@ def test_policy(env_name, policy_name, extra_overrides):
@pytest.mark.parametrize("policy_name", available_policies)
def test_policy_defaults(policy_name: str):
"""Check that the policy can be instantiated with defaults."""
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, _ = get_policy_and_config_classes(policy_name)
policy_cls()
@pytest.mark.parametrize("policy_name", available_policies)
def test_save_and_load_pretrained(policy_name: str):
if policy_name == "tdmpc":
with pytest.raises(NotImplementedError):
get_policy_and_config_classes(policy_name)
return
policy_cls, _ = get_policy_and_config_classes(policy_name)
policy: Policy = policy_cls()
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"