Add option for random policy
This commit is contained in:
@@ -4,9 +4,9 @@ import hydra
|
|||||||
import imageio
|
import imageio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tensordict import TensorDict
|
|
||||||
from tensordict.nn import TensorDictModule
|
from tensordict.nn import TensorDictModule
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
from torchrl.envs import EnvBase
|
||||||
|
|
||||||
from lerobot.common.envs.factory import make_env
|
from lerobot.common.envs.factory import make_env
|
||||||
from lerobot.common.tdmpc import TDMPC
|
from lerobot.common.tdmpc import TDMPC
|
||||||
@@ -14,7 +14,12 @@ from lerobot.common.utils import set_seed
|
|||||||
|
|
||||||
|
|
||||||
def eval_policy(
|
def eval_policy(
|
||||||
env, policy, num_episodes: int, save_video: bool = False, video_dir: Path = None
|
env: EnvBase,
|
||||||
|
policy: TensorDictModule = None,
|
||||||
|
num_episodes: int = 10,
|
||||||
|
max_steps: int = 30,
|
||||||
|
save_video: bool = False,
|
||||||
|
video_dir: Path = None,
|
||||||
):
|
):
|
||||||
rewards = []
|
rewards = []
|
||||||
successes = []
|
successes = []
|
||||||
@@ -31,7 +36,7 @@ def eval_policy(
|
|||||||
rendering_callback(env)
|
rendering_callback(env)
|
||||||
|
|
||||||
rollout = env.rollout(
|
rollout = env.rollout(
|
||||||
max_steps=30,
|
max_steps=max_steps,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
callback=rendering_callback,
|
callback=rendering_callback,
|
||||||
auto_reset=False,
|
auto_reset=False,
|
||||||
@@ -73,9 +78,10 @@ def eval(cfg: dict):
|
|||||||
out_keys=["action"],
|
out_keys=["action"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# policy can be None, in which case a random policy is rolled out
|
||||||
metrics = eval_policy(
|
metrics = eval_policy(
|
||||||
env,
|
env,
|
||||||
policy,
|
policy=policy,
|
||||||
num_episodes=10,
|
num_episodes=10,
|
||||||
save_video=True,
|
save_video=True,
|
||||||
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
|
video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
|
||||||
|
|||||||
Reference in New Issue
Block a user