Add option for random policy

This commit is contained in:
Cadene
2024-01-31 13:54:32 +00:00
parent 5a5b190f70
commit 937b2f8cba

View File

@@ -4,9 +4,9 @@ import hydra
import imageio import imageio
import numpy as np import numpy as np
import torch import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule from tensordict.nn import TensorDictModule
from termcolor import colored from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.envs.factory import make_env from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC from lerobot.common.tdmpc import TDMPC
@@ -14,7 +14,12 @@ from lerobot.common.utils import set_seed
def eval_policy( def eval_policy(
env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None env: EnvBase,
policy: TensorDictModule = None,
num_episodes: int = 10,
max_steps: int = 30,
save_video: bool = False,
video_dir: Path = None,
): ):
rewards = [] rewards = []
successes = [] successes = []
@@ -31,7 +36,7 @@ def eval_policy(
rendering_callback(env) rendering_callback(env)
rollout = env.rollout( rollout = env.rollout(
max_steps=30, max_steps=max_steps,
policy=policy, policy=policy,
callback=rendering_callback, callback=rendering_callback,
auto_reset=False, auto_reset=False,
@@ -73,9 +78,10 @@ def eval(cfg: dict):
out_keys=["action"], out_keys=["action"],
) )
# policy may be None, in which case the rollout uses randomly sampled actions
metrics = eval_policy( metrics = eval_policy(
env, env,
policy, policy=policy,
num_episodes=10, num_episodes=10,
save_video=True, save_video=True,
video_dir=Path("tmp/2023_01_29_xarm_lift_final"), video_dir=Path("tmp/2023_01_29_xarm_lift_final"),