backup wip
This commit is contained in:
@@ -4,11 +4,13 @@ import torch
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.policies.protocol import Policy
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.utils import init_hydra_config
|
||||
from .utils import DEVICE, DEFAULT_CONFIG_PATH
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_name,policy_name,extra_overrides",
|
||||
[
|
||||
@@ -27,6 +29,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
"""
|
||||
Tests:
|
||||
- Making the policy object.
|
||||
- Checking that the policy follows the correct protocol.
|
||||
- Updating the policy.
|
||||
- Using the policy to select actions at inference time.
|
||||
- Test the action can be applied to the policy
|
||||
@@ -38,10 +41,14 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
f"policy={policy_name}",
|
||||
f"device={DEVICE}",
|
||||
]
|
||||
+ extra_overrides
|
||||
+ extra_overrides,
|
||||
)
|
||||
# Check that we can make the policy object.
|
||||
policy = make_policy(cfg)
|
||||
# Check that the policy follows the required protocol.
|
||||
assert isinstance(
|
||||
policy, Policy
|
||||
), f"The policy does not follow the required protocol. Please see {Policy.__module__}.{Policy.__name__}."
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
dataset = make_dataset(cfg)
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
@@ -62,7 +69,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||
|
||||
# Test updating the policy
|
||||
policy(batch, step=0)
|
||||
policy.update(batch, step=0)
|
||||
|
||||
# reset the policy and environment
|
||||
policy.reset()
|
||||
@@ -83,4 +90,3 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user