Refactor eval.py (#127)
This commit is contained in:
@@ -37,7 +37,7 @@ def test_factory(env_name):
|
||||
overrides=[f"env={env_name}", f"device={DEVICE}"],
|
||||
)
|
||||
|
||||
env = make_env(cfg, num_parallel_envs=1)
|
||||
env = make_env(cfg, n_envs=1)
|
||||
obs, _ = env.reset()
|
||||
obs = preprocess_observation(obs)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from lerobot import available_policies
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.utils import cycle
|
||||
from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
|
||||
from lerobot.common.envs.utils import preprocess_observation
|
||||
from lerobot.common.policies.factory import get_policy_and_config_classes, make_policy
|
||||
from lerobot.common.policies.normalize import Normalize, Unnormalize
|
||||
from lerobot.common.policies.policy_protocol import Policy
|
||||
@@ -80,7 +80,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
assert isinstance(policy, PyTorchModelHubMixin)
|
||||
|
||||
# Check that we run select_actions and get the appropriate output.
|
||||
env = make_env(cfg, num_parallel_envs=2)
|
||||
env = make_env(cfg, n_envs=2)
|
||||
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
@@ -112,10 +112,7 @@ def test_policy(env_name, policy_name, extra_overrides):
|
||||
|
||||
# get the next action for the environment
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# convert action to cpu numpy array
|
||||
action = postprocess_action(action)
|
||||
action = policy.select_action(observation).cpu().numpy()
|
||||
|
||||
# Test step through policy
|
||||
env.step(action)
|
||||
|
||||
Reference in New Issue
Block a user