# Source file: lerobot_piper/lerobot/scripts/eval.py
# (Python; listed by the file viewer as 96 lines, 2.6 KiB)

from pathlib import Path
import hydra
import imageio
import numpy as np
import torch
from tensordict.nn import TensorDictModule
from termcolor import colored
from torchrl.envs import EnvBase
from lerobot.common.envs.factory import make_env
from lerobot.common.tdmpc import TDMPC
from lerobot.common.utils import set_seed
def eval_policy(
    env: EnvBase,
    policy: TensorDictModule = None,
    num_episodes: int = 10,
    max_steps: int = 30,
    save_video: bool = False,
    video_dir: Path = None,
    fps: int = 15,
):
    """Roll out ``policy`` in ``env`` for several episodes and average the results.

    Each episode resets the environment, runs ``env.rollout`` for at most
    ``max_steps`` steps, sums the ``("next", "reward")`` entries and checks
    whether any ``("next", "success")`` entry is set.

    Args:
        env: Environment providing ``reset()``, ``rollout(...)`` and, when
            videos are recorded, ``render()``.
        policy: Forwarded unchanged to ``env.rollout``. ``None`` is accepted.
        num_episodes: Number of episodes to run.
        max_steps: Step limit handed to ``env.rollout``.
        save_video: When true, one ``eval_episode_{i}.mp4`` file per episode
            is written into ``video_dir``.
        video_dir: Output directory for videos; required when ``save_video``
            is true. Created (with parents) if it does not exist.
        fps: Frame rate of the written videos.

    Returns:
        A dict with ``"avg_reward"`` (mean of the per-episode reward sums)
        and ``"pc_success"`` (percentage of episodes with a success flag).
        Both values are NaN when no episode was run.

    Raises:
        ValueError: If ``save_video`` is true but ``video_dir`` is ``None``.
    """
    if save_video:
        # Fail before any rollout is spent rather than after the first episode.
        if video_dir is None:
            raise ValueError("video_dir must be provided when save_video=True")
        video_dir = Path(video_dir)
        video_dir.mkdir(parents=True, exist_ok=True)

    # Frame buffer reused across episodes; cleared at the start of each one.
    frames = []

    def rendering_callback(rollout_env, td=None):
        # Called once per rollout step; `td` is accepted but not used.
        frames.append(rollout_env.render())

    rewards = []
    successes = []
    for episode in range(num_episodes):
        frames.clear()

        tensordict = env.reset()
        if save_video:
            # Render the first frame before the rollout starts.
            rendering_callback(env)

        rollout = env.rollout(
            max_steps=max_steps,
            policy=policy,
            callback=rendering_callback if save_video else None,
            auto_reset=False,
            tensordict=tensordict,
        )
        rewards.append(rollout["next", "reward"].sum().item())
        successes.append(rollout["next", "success"].any().item())

        if save_video:
            video_path = video_dir / f"eval_episode_{episode}.mp4"
            # np.stack copies the frames, so reusing the buffer afterwards is safe.
            imageio.mimsave(video_path, np.stack(frames), fps=fps)

    if not rewards:
        # np.nanmean on an empty list warns and yields NaN; return NaN quietly.
        return {"avg_reward": float("nan"), "pc_success": float("nan")}

    return {
        "avg_reward": np.nanmean(rewards),
        "pc_success": np.nanmean(successes) * 100,
    }
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def eval(cfg: dict):
    """Load a TDMPC checkpoint and print evaluation metrics for it.

    Seeds the run from ``cfg.seed``, builds the environment and policy from
    ``cfg``, loads the checkpoint, wraps the policy so it reads
    ``"observation"`` and ``"step_count"`` and writes ``"action"``, then runs
    ``eval_policy`` for 20 episodes and prints the resulting metrics.

    Args:
        cfg: Configuration object supplied by the ``hydra.main`` decorator.
            It is read by attribute (``cfg.seed``, ``cfg.log_dir``), so it is
            presumably an OmegaConf config rather than a plain ``dict`` —
            TODO confirm and tighten the annotation.

    Raises:
        RuntimeError: If no CUDA device is available.
    """
    # Raise explicitly instead of using `assert`, which `python -O` strips.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for evaluation but is not available")

    set_seed(cfg.seed)
    print(colored("Log dir:", "yellow", attrs=["bold"]), cfg.log_dir)

    env = make_env(cfg)
    policy = TDMPC(cfg)

    # NOTE(review): hard-coded, user-specific checkpoint location; this should
    # come from the config. The sibling "offline.pt" in the same directory was
    # the alternative checkpoint used here previously.
    ckpt_path = "/home/rcadene/code/fowm/logs/xarm_lift/all/default/2/models/final.pt"
    policy.load(ckpt_path)

    policy = TensorDictModule(
        policy,
        in_keys=["observation", "step_count"],
        out_keys=["action"],
    )

    # policy can be None to rollout a random policy
    metrics = eval_policy(
        env,
        policy=policy,
        num_episodes=20,
        save_video=False,
        video_dir=Path("tmp/2023_01_29_xarm_lift_final"),
    )
    print(metrics)
if __name__ == "__main__":
    # Script entry point: the hydra decorator on `eval` supplies its `cfg`
    # argument, so it is called here without parameters.
    eval()