From 937b2f8cba085c0c42460daa1d57f67180f596a2 Mon Sep 17 00:00:00 2001 From: Cadene Date: Wed, 31 Jan 2024 13:54:32 +0000 Subject: [PATCH] Add option for random policy --- lerobot/scripts/eval.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 58558928c..4137e5d06 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -4,9 +4,9 @@ import hydra import imageio import numpy as np import torch -from tensordict import TensorDict from tensordict.nn import TensorDictModule from termcolor import colored +from torchrl.envs import EnvBase from lerobot.common.envs.factory import make_env from lerobot.common.tdmpc import TDMPC @@ -14,7 +14,12 @@ from lerobot.common.utils import set_seed def eval_policy( - env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None + env: EnvBase, + policy: TensorDictModule = None, + num_episodes: int = 10, + max_steps: int = 30, + save_video: bool = False, + video_dir: Path = None, ): rewards = [] successes = [] @@ -31,7 +36,7 @@ def eval_policy( rendering_callback(env) rollout = env.rollout( - max_steps=30, + max_steps=max_steps, policy=policy, callback=rendering_callback, auto_reset=False, @@ -73,9 +78,10 @@ def eval(cfg: dict): out_keys=["action"], ) + # policy can be None to rollout a random policy metrics = eval_policy( env, - policy, + policy=policy, num_episodes=10, save_video=True, video_dir=Path("tmp/2023_01_29_xarm_lift_final"),