Refactor eval.py (#127)

This commit is contained in:
Alexander Soare
2024-05-03 17:33:16 +01:00
committed by GitHub
parent b7b69fcc3d
commit bccee745c3
12 changed files with 457 additions and 298 deletions

View File

@@ -37,7 +37,7 @@ def test_factory(env_name):
overrides=[f"env={env_name}", f"device={DEVICE}"],
)
env = make_env(cfg, num_parallel_envs=1)
env = make_env(cfg, n_envs=1)
obs, _ = env.reset()
obs = preprocess_observation(obs)

View File

@@ -8,7 +8,7 @@ from lerobot import available_policies
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
from lerobot.common.policies.normalize import Normalize, Unnormalize
from lerobot.common.policies.policy_protocol import Policy
@@ -80,7 +80,7 @@ def test_policy(env_name, policy_name, extra_overrides):
assert isinstance(policy, PyTorchModelHubMixin)
# Check that we run select_actions and get the appropriate output.
env = make_env(cfg, num_parallel_envs=2)
env = make_env(cfg, n_envs=2)
dataloader = torch.utils.data.DataLoader(
dataset,
@@ -112,10 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# get the next action for the environment
with torch.inference_mode():
action = policy.select_action(observation)
# convert action to cpu numpy array
action = postprocess_action(action)
action = policy.select_action(observation).cpu().numpy()
# Test step through policy
env.step(action)