# (file listing metadata: 282 lines, 12 KiB, Python)
import argparse
import sys
import os
import time

import imageio
import numpy as np
import h5py
import torch

# Import AppLauncher first and instantiate SimulationApp as early as possible,
# so the pxr (USD) runtime is loaded before any isaac module that needs it.
from isaaclab.app import AppLauncher
import cli_args  # isort: skip

# CLI
parser = argparse.ArgumentParser(description="Play an RL agent with multi-cam recording.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations.")
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument("--use_pretrained_checkpoint", action="store_true", help="Use the pre-trained checkpoint from Nucleus.")
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
parser.add_argument("--max_steps", type=int, default=None, help="最大步数,达到后提前退出")
cli_args.add_rsl_rl_args(parser)
AppLauncher.add_app_launcher_args(parser)
args_cli, hydra_args = parser.parse_known_args()
if args_cli.video:
    # Video capture requires the camera pipeline to be active.
    args_cli.enable_cameras = True
# Forward the remaining (unparsed) arguments to hydra.
sys.argv = [sys.argv[0]] + hydra_args

# ==== Instantiate SimulationApp first ====
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app

# ==== Only afterwards import modules that depend on isaac/pxr ====
import gymnasium as gym
from rsl_rl.runners import DistillationRunner, OnPolicyRunner
from isaaclab.envs import DirectMARLEnv, DirectMARLEnvCfg, DirectRLEnvCfg, ManagerBasedRLEnvCfg, multi_agent_to_single_agent
from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx
from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint

import isaaclab_tasks  # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path
from isaaclab_tasks.utils.hydra import hydra_task_config
import mindbot.tasks  # noqa: F401

# Camera sensor names to record from (must match the scene configuration).
CAM_NAMES = ["cam_head", "cam_chest", "cam_left_hand", "cam_right_hand", "cam_top", "cam_side"]
|
||
|
||
|
||
@hydra_task_config(args_cli.task, args_cli.agent)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
    """Play a trained RL policy and record multi-camera frames plus robot state.

    Resolves a checkpoint from the CLI, exports the policy (JIT + ONNX), then
    rolls the policy out while buffering per-step camera frames (env 0), joint
    positions/velocities, actions and timestamps, and finally writes everything
    to an HDF5 file (plus an optional head-camera MP4).

    Args:
        env_cfg: Environment configuration injected by hydra.
        agent_cfg: RSL-RL runner configuration injected by hydra.
    """
    task_name = args_cli.task.split(":")[-1]
    # Published pretrained checkpoints are keyed by the training task name.
    train_task_name = task_name.replace("-Play", "")

    # Apply CLI overrides to the agent/env configuration.
    agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
    env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
    env_cfg.seed = agent_cfg.seed
    env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

    # Resolve which checkpoint to load.
    log_root_path = os.path.abspath(os.path.join("logs", "rsl_rl", agent_cfg.experiment_name))
    print(f"[INFO] Loading experiment from directory: {log_root_path}")
    if args_cli.use_pretrained_checkpoint:
        resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
        if not resume_path:
            print("[INFO] No pre-trained checkpoint for this task.")
            return
    elif args_cli.checkpoint:
        resume_path = retrieve_file_path(args_cli.checkpoint)
    else:
        resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)

    # All outputs go next to the checkpoint.
    log_dir = os.path.dirname(resume_path)
    env_cfg.log_dir = log_dir

    # Create the environment and wrap it for RSL-RL.
    env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
    if isinstance(env.unwrapped, DirectMARLEnv):
        # RSL-RL only supports single-agent interfaces.
        env = multi_agent_to_single_agent(env)
    env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)

    # Instantiate the runner matching the training algorithm and load weights.
    print(f"[INFO]: Loading model checkpoint from: {resume_path}")
    if agent_cfg.class_name == "OnPolicyRunner":
        runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
    elif agent_cfg.class_name == "DistillationRunner":
        runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
    else:
        raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
    runner.load(resume_path)

    policy = runner.get_inference_policy(device=env.unwrapped.device)
    try:
        policy_nn = runner.alg.policy
    except AttributeError:
        # Older rsl_rl versions expose the network as `actor_critic`.
        policy_nn = runner.alg.actor_critic

    # Export the policy next to the checkpoint (TorchScript + ONNX).
    export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
    normalizer = getattr(policy_nn, "actor_obs_normalizer", getattr(policy_nn, "student_obs_normalizer", None))
    export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
    export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")

    dt = env.unwrapped.step_dt

    # Recording buffers.
    cam_buffers = {n: [] for n in CAM_NAMES}
    joint_log, joint_vel_log, action_log, ts_log = [], [], [], []
    t0 = time.time()

    obs = env.get_observations()
    timestep = 0
    step_count = 0
    while simulation_app.is_running():
        start_time = time.time()
        with torch.inference_mode():
            actions = policy(obs)
            obs, _, dones, _ = env.step(actions)
            policy_nn.reset(dones)

        # Camera frames (env 0 only; drop the [0] index to record every env).
        for name in CAM_NAMES:
            if name not in env.unwrapped.scene.sensors:
                continue
            cam = env.unwrapped.scene.sensors[name]
            rgba = cam.data.output.get("rgba", cam.data.output.get("rgb"))
            if rgba is None:
                continue
            frame = rgba[0].cpu().numpy()
            if frame.shape[-1] == 4:
                # Drop the alpha channel.
                frame = frame[..., :3]
            cam_buffers[name].append(frame)

        # Joint positions / velocities / actions / wall-clock timestamps.
        robot = env.unwrapped.scene["Mindbot"]
        joint_log.append(robot.data.joint_pos.cpu().numpy())
        joint_vel_log.append(robot.data.joint_vel.cpu().numpy())
        action_log.append(actions.cpu().numpy())
        ts_log.append(time.time() - t0)

        # Fix: honor --max_steps — it was parsed but never checked in the
        # live loop (only the removed commented-out copy used it).
        step_count += 1
        if args_cli.max_steps and step_count >= args_cli.max_steps:
            break

        if args_cli.video:
            timestep += 1
            if timestep == args_cli.video_length:
                break

        # Optionally pace the loop to real time.
        sleep_time = dt - (time.time() - start_time)
        if args_cli.real_time and sleep_time > 0:
            time.sleep(sleep_time)

    # Save the rollout to HDF5.
    h5_path = os.path.join(log_dir, "rollout_multi_cam.h5")
    with h5py.File(h5_path, "w") as f:
        # Guard: np.stack raises ValueError on an empty sequence (e.g. the
        # sim was closed before a single step completed).
        if joint_log:
            f.create_dataset("joint_pos", data=np.stack(joint_log), compression="gzip")
            f.create_dataset("joint_vel", data=np.stack(joint_vel_log), compression="gzip")
            f.create_dataset("actions", data=np.stack(action_log), compression="gzip")
            f.create_dataset("timestamps", data=np.array(ts_log))
        for name, frames in cam_buffers.items():
            if not frames:
                continue
            dset = f.create_dataset(f"cams/{name}/rgb", data=np.stack(frames), compression="gzip")
            # Intrinsics: match the camera vectors in the scene config.
            if name in ["cam_head", "cam_chest", "cam_left_hand", "cam_right_hand"]:
                fx, fy, cx, cy = 911.77, 911.5, 624.07, 364.05
            else:
                fx, fy, cx, cy = 458.7488, 458.8663, 323.3297, 240.6295
            dset.attrs["fx"] = fx
            dset.attrs["fy"] = fy
            dset.attrs["cx"] = cx
            dset.attrs["cy"] = cy
            dset.attrs["focal_length"] = 1.93
    print(f"[INFO] Saved HDF5 to {h5_path}")

    # Optional: single head-camera MP4.
    head_frames = cam_buffers["cam_head"]
    if head_frames:
        fps = int(round(1.0 / dt))
        video_path = os.path.join(log_dir, "cam_head.mp4")
        imageio.mimsave(video_path, head_frames, fps=fps)
        print(f"[INFO] Saved video to {video_path}")

    env.close()
|
||
|
||
|
||
if __name__ == "__main__":
    # Run the hydra-decorated entry point, then shut down the simulator.
    main()
    simulation_app.close()
|
||
|
||
# Offline helper (not executed by this script): paste into a separate session
# to inspect the structure of a saved rollout HDF5 file.
# import h5py
#
# h5_path = r"C:\Users\PC\workpalce\mindbot\logs\rsl_rl\mindbot_grasp\2026-01-15_11-51-00\rollout_multi_cam.h5"
#
# with h5py.File(h5_path, "r") as f:
#     def walk(name, obj):
#         if isinstance(obj, h5py.Dataset):
#             print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")
#         else:
#             print(f"{name}/")
#     f.visititems(walk)