import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase

from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.utils import init_hydra_config

from .utils import DEVICE, DEFAULT_CONFIG_PATH


@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
        ("simxarm", "tdmpc", ["policy.mpc=true"]),
        ("pusht", "tdmpc", ["policy.mpc=false"]),
        ("pusht", "diffusion", []),
        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): simxarm not working with diffusion
        # ("simxarm", "diffusion", []),
    ],
)
def test_concrete_policy(env_name, policy_name, extra_overrides):
    """
    Tests:
    - Making the policy object.
    - Updating the policy.
    - Using the policy to select actions at inference time.
    """
    cfg = init_hydra_config(
        DEFAULT_CONFIG_PATH,
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ]
        + extra_overrides,
    )
    # Check that we can make the policy object.
    policy = make_policy(cfg)
    # Check that we run select_actions and get the appropriate output.
    dataset = make_dataset(cfg)
    env = make_env(cfg, transform=dataset.transform)

    if env_name != "aloha":
        # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
        # seq_length as a list is not supported for now.
        policy.update(dataset, torch.tensor(0, device=DEVICE))

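    # Query the policy with a random observation and check that the action shape matches the env's action spec.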
    action = policy(
        env.observation_spec.rand()["observation"].to(DEVICE),
        torch.tensor(0, device=DEVICE),
    )
    assert action.shape == env.action_spec.shape


def test_abstract_policy_forward():
    """
    Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
    - The policy is invoked the expected number of times during a rollout.
    - The environment's termination condition is respected even when part way through an action trajectory.
    - The observations are returned correctly.
    """

    n_action_steps = 8  # our test policy will output 8 action step horizons
    terminate_at = 10  # some number that is more than n_action_steps but not a multiple
    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at

    # A minimal environment for testing.
    class StubEnv(EnvBase):
        def __init__(self):
            super().__init__()
            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))

        def _step(self, tensordict: TensorDict) -> TensorDict:
            self.invocation_count += 1
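            # Report the step count as both observation and reward; terminate once the policy outputs the action equal to terminate_at.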
            return TensorDict(
                {
                    "observation": torch.tensor([self.invocation_count]),
                    "reward": torch.tensor([self.invocation_count]),
                    "terminated": torch.tensor(
                        tensordict["action"].item() == terminate_at
                    ),
                }
            )

        def _reset(self, tensordict: TensorDict) -> TensorDict:
            self.invocation_count = 0
            return TensorDict(
                {
                    "observation": torch.tensor([self.invocation_count]),
                    "reward": torch.tensor([self.invocation_count]),
                }
            )

        def _set_seed(self, seed: int | None):
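            # The stub env is deterministic, so seeding is a no-op.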
            return

    class StubPolicy(AbstractPolicy):
        name = "stub"

        def __init__(self):
            super().__init__(n_action_steps)
            self.n_policy_invocations = 0

        def update(self):
            pass

        def select_actions(self):
            self.n_policy_invocations += 1
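            # Return a (1, n_action_steps, 1) batch of dummy actions: 0, 1, ..., n_action_steps - 1.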
            return torch.stack(
                [torch.tensor([i]) for i in range(self.n_action_steps)]
            ).unsqueeze(0)

    env = StubEnv()
    policy = StubPolicy()
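    # Wrap the policy in a TensorDictModule so that env.rollout can call it through tensordict keys.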
    policy = TensorDictModule(
        policy,
        in_keys=[],
        out_keys=["action"],
    )

    # Run the rollout; StubPolicy keeps track of how many times it is invoked.
    rollout = env.rollout(rollout_max_steps, policy)

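    # The rollout should stop at the termination step, partway through the second action trajectory.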
    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))