diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 75173e3f..74780b10 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -49,6 +49,7 @@ from torchrl.envs.batched_envs import BatchedEnvBase from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env +from lerobot.common.envs.simxarm.env import SimxarmEnv from lerobot.common.logger import log_output_dir from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy @@ -86,7 +87,10 @@ def eval_policy( def maybe_render_frame(env: EnvBase, _): if save_video or (return_first_video and i == 0): # noqa: B023 - ep_frames.append(env.render(mode="visualization")) # noqa: B023 + # HACK + # TODO(aliberts): set render_mode for all envs + render_mode = "visualization" if isinstance(env, SimxarmEnv) else "rgb_array" + ep_frames.append(env.render(mode=render_mode)) # noqa: B023 # Clear the policy's action queue before the start of a new rollout. if policy is not None: