lerobot/tests/test_policies.py
from omegaconf import open_dict
import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
import torch
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase

from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.abstract import AbstractPolicy

from .utils import DEVICE, init_config


@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
        ("simxarm", "tdmpc", ["policy.mpc=true"]),
        ("pusht", "tdmpc", ["policy.mpc=false"]),
        ("pusht", "diffusion", []),
        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_human"]),
        ("aloha", "act", ["env.task=sim_insertion", "dataset_id=aloha_sim_insertion_scripted"]),
        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_human"]),
        ("aloha", "act", ["env.task=sim_transfer_cube", "dataset_id=aloha_sim_transfer_cube_scripted"]),
        # TODO(aliberts): simxarm not working with diffusion
        # ("simxarm", "diffusion", []),
    ],
)
def test_concrete_policy(env_name, policy_name, extra_overrides):
    """
    Tests:
        - Making the policy object.
        - Updating the policy.
        - Using the policy to select actions at inference time.
    """
    cfg = init_config(
        overrides=[
            f"env={env_name}",
            f"policy={policy_name}",
            f"device={DEVICE}",
        ]
        + extra_overrides
    )

    # Check that we can make the policy object.
    policy = make_policy(cfg)

    # Check that we run select_actions and get the appropriate output.
    offline_buffer = make_offline_buffer(cfg)
    env = make_env(cfg, transform=offline_buffer.transform)
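
    # Check that we can update the policy (cf. "Updating the policy" in the docstring above).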
    if env_name != "aloha":
        # TODO(alexander-soare): Fix this part of the test. PrioritizedSliceSampler raises NotImplementedError:
        # seq_length as a list is not supported for now.
        policy.update(offline_buffer, torch.tensor(0, device=DEVICE))

    action = policy(
        env.observation_spec.rand()["observation"].to(DEVICE),
        torch.tensor(0, device=DEVICE),
    )
    assert action.shape == env.action_spec.shape


def test_abstract_policy_forward():
    """
    Given an underlying policy that produces an action trajectory with n_action_steps actions, checks that:
        - The policy is invoked the expected number of times during a rollout.
        - The environment's termination condition is respected even when part way through an action trajectory.
        - The observations are returned correctly.
    """
    n_action_steps = 8  # our test policy will output 8 action step horizons
    terminate_at = 10  # some number that is more than n_action_steps but not a multiple
    rollout_max_steps = terminate_at + 1  # some number greater than terminate_at

    # A minimal environment for testing.
    class StubEnv(EnvBase):
        def __init__(self):
            super().__init__()
            self.action_spec = UnboundedContinuousTensorSpec(shape=(1,))
            self.reward_spec = UnboundedContinuousTensorSpec(shape=(1,))

        def _step(self, tensordict: TensorDict) -> TensorDict:
            self.invocation_count += 1
            return TensorDict(
                {
                    "observation": torch.tensor([self.invocation_count]),
                    "reward": torch.tensor([self.invocation_count]),
                    "terminated": torch.tensor(
                        tensordict["action"].item() == terminate_at
                    ),
                }
            )

        def _reset(self, tensordict: TensorDict) -> TensorDict:
            self.invocation_count = 0
            return TensorDict(
                {
                    "observation": torch.tensor([self.invocation_count]),
                    "reward": torch.tensor([self.invocation_count]),
                }
            )

        def _set_seed(self, seed: int | None):
            return
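
    # A minimal policy for testing: every call to select_actions returns the same trajectory of
    # n_action_steps actions (0, 1, ..., n_action_steps - 1) and counts how often it was queried.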
    class StubPolicy(AbstractPolicy):
        name = "stub"

        def __init__(self):
            super().__init__(n_action_steps)
            self.n_policy_invocations = 0

        def update(self):
            pass

        def select_actions(self):
            self.n_policy_invocations += 1
            return torch.stack(
                [torch.tensor([i]) for i in range(self.n_action_steps)]
            ).unsqueeze(0)

    env = StubEnv()
    policy = StubPolicy()
    policy = TensorDictModule(
        policy,
        in_keys=[],
        out_keys=["action"],
    )
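
    # With n_action_steps=8, terminate_at=10 and rollout_max_steps=11, the rollout below is expected
    # to take 11 environment steps, which requires querying the policy's 8-action trajectory twice
    # and yields observations counting up from 0 (the reset) to 10.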
    # Keep track to make sure the policy is called the expected number of times
    rollout = env.rollout(rollout_max_steps, policy)

    assert len(rollout) == terminate_at + 1  # +1 for the reset observation
    assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
    assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))