forked from tangger/lerobot
Refactor TD-MPC (#103)
Co-authored-by: Cadene <re.cadene@gmail.com> Co-authored-by: Simon Alibert <75076266+aliberts@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,7 @@ import pytest
|
||||
import lerobot
|
||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
from lerobot.common.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
|
||||
from tests.utils import require_env
|
||||
|
||||
|
||||
|
||||
@@ -19,10 +19,6 @@ from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_get_policy_and_config_classes(policy_name: str):
|
||||
"""Check that the correct policy and config classes are returned."""
|
||||
if policy_name == "tdmpc":
|
||||
with pytest.raises(NotImplementedError):
|
||||
get_policy_and_config_classes(policy_name)
|
||||
return
|
||||
policy_cls, config_cls = get_policy_and_config_classes(policy_name)
|
||||
assert policy_cls.name == policy_name
|
||||
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
|
||||
@@ -32,8 +28,7 @@ def test_get_policy_and_config_classes(policy_name: str):
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
# ("xarm", "tdmpc", ["policy.mpc=true"]),
|
||||
# ("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("xarm", "tdmpc", ["policy.use_mpc=true", "dataset_repo_id=lerobot/xarm_lift_medium"]),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_repo_id=lerobot/aloha_sim_insertion_human"]),
|
||||
(
|
||||
@@ -103,7 +98,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy
|
||||
policy.forward(batch, step=0)
|
||||
policy.forward(batch)
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
@@ -117,7 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation, step=0)
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
@@ -129,20 +124,12 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_policy_defaults(policy_name: str):
|
||||
"""Check that the policy can be instantiated with defaults."""
|
||||
if policy_name == "tdmpc":
|
||||
with pytest.raises(NotImplementedError):
|
||||
get_policy_and_config_classes(policy_name)
|
||||
return
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy_cls()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_name", available_policies)
|
||||
def test_save_and_load_pretrained(policy_name: str):
|
||||
if policy_name == "tdmpc":
|
||||
with pytest.raises(NotImplementedError):
|
||||
get_policy_and_config_classes(policy_name)
|
||||
return
|
||||
policy_cls, _ = get_policy_and_config_classes(policy_name)
|
||||
policy: Policy = policy_cls()
|
||||
save_dir = "/tmp/test_save_and_load_pretrained_{policy_cls.__name__}"
|
||||
|
||||
Reference in New Issue
Block a user