backup wip

This commit is contained in:
Alexander Soare
2024-04-16 16:31:44 +01:00
parent 43a614c173
commit 23be5e1e7b
4 changed files with 24 additions and 17 deletions

View File

@@ -4,11 +4,13 @@ import torch
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.protocol import Policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:
- Making the policy object.
- Checking that the policy follows the correct protocol.
- Updating the policy.
- Using the policy to select actions at inference time.
- Test the action can be applied to the policy
@@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides):
f"policy={policy_name}",
f"device={DEVICE}",
]
+ extra_overrides
+ extra_overrides,
)
# Check that we can make the policy object.
policy = make_policy(cfg)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
# Check that we run select_actions and get the appropriate output.
dataset = make_dataset(cfg)
env = make_env(cfg, num_parallel_envs=2)
@@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides):
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy
policy(batch, step=0)
policy.update(batch, step=0)
# reset the policy and environment
policy.reset()
@@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides):
# Test step through policy
env.step(action)