forked from tangger/lerobot
backup wip
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,3 @@
|
||||
|
||||
from omegaconf import open_dict
|
||||
import pytest
|
||||
from tensordict import TensorDict
|
||||
@@ -16,35 +15,50 @@ from .utils import DEVICE, init_config
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name",
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
("simxarm", "tdmpc"),
|
||||
("pusht", "tdmpc"),
|
||||
("simxarm", "diffusion"),
|
||||
("pusht", "diffusion"),
|
||||
("simxarm", "tdmpc", ["policy.mpc=true"]),
|
||||
("pusht", "tdmpc", ["policy.mpc=false"]),
|
||||
("simxarm", "diffusion", []),
|
||||
("pusht", "diffusion", []),
|
||||
("aloha", "act", ["env.task=sim_insertion_scripted"]),
|
||||
],
|
||||
)
|
||||
def test_factory(env_name, policy_name):
|
||||
def test_concrete_policy(env_name, policy_name, extra_overrides):
|
||||
"""
|
||||
Tests:
|
||||
- Making the policy object.
|
||||
- Updating the policy.
|
||||
- Using the policy to select actions at inference time.
|
||||
"""
|
||||
cfg = init_config(
|
||||
overrides=[
|
||||
f"env={env_name}",
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
+ extra_overrides
|
||||
)
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
# Check that we run select_action and get the appropriate output.
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
if env_name == "simxarm":
|
||||
# TODO(rcadene): Not implemented
|
||||
return
|
||||
if policy_name == "tdmpc":
|
||||
# TODO(alexander-soare): TDMPC does not use n_obs_steps but the environment requires this.
|
||||
with open_dict(cfg):
|
||||
cfg['n_obs_steps'] = 1
|
||||
cfg["n_obs_steps"] = 1
|
||||
offline_buffer = make_offline_buffer(cfg)
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
policy.select_action(env.observation_spec.rand()['observation'].to(DEVICE), torch.tensor(0, device=DEVICE))
|
||||
|
||||
policy.update(offline_buffer, torch.tensor(0, device=DEVICE))
|
||||
|
||||
action = policy(
|
||||
env.observation_spec.rand()["observation"].to(DEVICE),
|
||||
torch.tensor(0, device=DEVICE),
|
||||
)
|
||||
assert action.shape == env.action_spec.shape
|
||||
|
||||
|
||||
def test_abstract_policy_forward():
|
||||
@@ -90,21 +104,20 @@ def test_abstract_policy_forward():
|
||||
|
||||
def _set_seed(self, seed: int | None):
|
||||
return
|
||||
|
||||
|
||||
class StubPolicy(AbstractPolicy):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.n_action_steps = n_action_steps
|
||||
super().__init__(n_action_steps)
|
||||
self.n_policy_invocations = 0
|
||||
|
||||
def update(self):
|
||||
pass
|
||||
|
||||
def select_action(self):
|
||||
def select_actions(self):
|
||||
self.n_policy_invocations += 1
|
||||
return torch.stack([torch.tensor([i]) for i in range(self.n_action_steps)]).unsqueeze(0)
|
||||
|
||||
return torch.stack(
|
||||
[torch.tensor([i]) for i in range(self.n_action_steps)]
|
||||
).unsqueeze(0)
|
||||
|
||||
env = StubEnv()
|
||||
policy = StubPolicy()
|
||||
@@ -119,4 +132,4 @@ def test_abstract_policy_forward():
|
||||
|
||||
assert len(rollout) == terminate_at + 1 # +1 for the reset observation
|
||||
assert policy.n_policy_invocations == (terminate_at // n_action_steps) + 1
|
||||
assert torch.equal(rollout['observation'].flatten(), torch.arange(terminate_at + 1))
|
||||
assert torch.equal(rollout["observation"].flatten(), torch.arange(terminate_at + 1))
|
||||
|
||||
Reference in New Issue
Block a user