# mindbot/test.py

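"""Play a trained RSL-RL policy while recording multi-camera rollouts.

Loads a checkpoint, steps the policy in simulation, and logs per-step joint
positions/velocities, actions, timestamps, and RGB frames from up to six
cameras. Results are written to ``rollout_multi_cam.h5`` in the run's log
directory, plus an optional ``cam_head.mp4``.
"""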
import argparse
import sys
import os
import time
import imageio
import numpy as np
import h5py
import torch
# Import AppLauncher first and instantiate SimulationApp as early as possible, so pxr is loaded before anything needs it
from isaaclab.app import AppLauncher
import cli_args # isort: skip
# CLI
parser = argparse.ArgumentParser(description="Play an RL agent with multi-cam recording.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations.")
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument("--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point.")
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument("--use_pretrained_checkpoint", action="store_true", help="Use the pre-trained checkpoint from Nucleus.")
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
parser.add_argument("--max_steps", type=int, default=None, help="Maximum number of steps; exit early once reached.")
cli_args.add_rsl_rl_args(parser)
AppLauncher.add_app_launcher_args(parser)
args_cli, hydra_args = parser.parse_known_args()
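# Example invocation (placeholder task/checkpoint values, adjust to your setup):
#   python mindbot/test.py --task <task-name> --num_envs 1 --checkpoint <path/to/model.pt> --max_steps 500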
if args_cli.video:
    args_cli.enable_cameras = True
# Pass the remaining arguments through to Hydra
sys.argv = [sys.argv[0]] + hydra_args
# ==== Instantiate SimulationApp first ====
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app
# ==== Only after the app exists, import modules that depend on isaac/pxr ====
import gymnasium as gym
from rsl_rl.runners import DistillationRunner, OnPolicyRunner
from isaaclab.envs import DirectMARLEnv, DirectMARLEnvCfg, DirectRLEnvCfg, ManagerBasedRLEnvCfg, multi_agent_to_single_agent
from isaaclab.utils.assets import retrieve_file_path
from isaaclab.utils.dict import print_dict
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx
from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
import isaaclab_tasks # noqa: F401
from isaaclab_tasks.utils import get_checkpoint_path
from isaaclab_tasks.utils.hydra import hydra_task_config
import mindbot.tasks # noqa: F401
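# Camera sensors expected in the scene; any camera missing from the scene is skipped during recording.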
CAM_NAMES = ["cam_head", "cam_chest", "cam_left_hand", "cam_right_hand", "cam_top", "cam_side"]
@hydra_task_config(args_cli.task, args_cli.agent)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
    task_name = args_cli.task.split(":")[-1]
    train_task_name = task_name.replace("-Play", "")
    agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
    env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
    env_cfg.seed = agent_cfg.seed
    env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
    log_root_path = os.path.abspath(os.path.join("logs", "rsl_rl", agent_cfg.experiment_name))
    print(f"[INFO] Loading experiment from directory: {log_root_path}")
    if args_cli.use_pretrained_checkpoint:
        resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
        if not resume_path:
            print("[INFO] No pre-trained checkpoint for this task.")
            return
    elif args_cli.checkpoint:
        resume_path = retrieve_file_path(args_cli.checkpoint)
    else:
        resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
    log_dir = os.path.dirname(resume_path)
    env_cfg.log_dir = log_dir
    env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
    if isinstance(env.unwrapped, DirectMARLEnv):
        env = multi_agent_to_single_agent(env)
    env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)
    print(f"[INFO]: Loading model checkpoint from: {resume_path}")
    if agent_cfg.class_name == "OnPolicyRunner":
        runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
    elif agent_cfg.class_name == "DistillationRunner":
        runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
    else:
        raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
    runner.load(resume_path)
    policy = runner.get_inference_policy(device=env.unwrapped.device)
    try:
        policy_nn = runner.alg.policy
    except AttributeError:
        policy_nn = runner.alg.actor_critic
    # Export the policy (TorchScript and ONNX)
    export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
    normalizer = getattr(policy_nn, "actor_obs_normalizer", getattr(policy_nn, "student_obs_normalizer", None))
    export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
    export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")
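    # The exported copies land next to the checkpoint in <run-dir>/exported/ (policy.pt, policy.onnx).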
    dt = env.unwrapped.step_dt
    # Recording buffers
    cam_buffers = {n: [] for n in CAM_NAMES}
    joint_log, joint_vel_log, action_log, ts_log = [], [], [], []
    t0 = time.time()
    obs = env.get_observations()
    timestep = 0
    step_count = 0
    while simulation_app.is_running():
        start_time = time.time()
        with torch.inference_mode():
            actions = policy(obs)
            obs, _, dones, _ = env.step(actions)
            policy_nn.reset(dones)
        # Camera frames (env 0 only; drop the [0] to record all envs)
        for name in CAM_NAMES:
            if name not in env.unwrapped.scene.sensors:
                continue
            cam = env.unwrapped.scene.sensors[name]
            rgba = cam.data.output.get("rgba", cam.data.output.get("rgb"))
            if rgba is None:
                continue
            frame = rgba[0].cpu().numpy()
            if frame.shape[-1] == 4:
                frame = frame[..., :3]
            cam_buffers[name].append(frame)
        # Joint positions / velocities / actions
        robot = env.unwrapped.scene["Mindbot"]
        joint_log.append(robot.data.joint_pos.cpu().numpy())
        joint_vel_log.append(robot.data.joint_vel.cpu().numpy())
        action_log.append(actions.cpu().numpy())
        ts_log.append(time.time() - t0)
        # Honor --max_steps: exit early once the step budget is reached
        step_count += 1
        if args_cli.max_steps and step_count >= args_cli.max_steps:
            break
        if args_cli.video:
            timestep += 1
            if timestep == args_cli.video_length:
                break
        sleep_time = dt - (time.time() - start_time)
        if args_cli.real_time and sleep_time > 0:
            time.sleep(sleep_time)
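    # Note: joint/action entries are (num_envs, ...) arrays and camera entries are single-env (H, W, 3) frames,
    # so np.stack below adds a leading time axis in both cases.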
    # Save HDF5
    h5_path = os.path.join(log_dir, "rollout_multi_cam.h5")
    with h5py.File(h5_path, "w") as f:
        f.create_dataset("joint_pos", data=np.stack(joint_log), compression="gzip")
        f.create_dataset("joint_vel", data=np.stack(joint_vel_log), compression="gzip")
        f.create_dataset("actions", data=np.stack(action_log), compression="gzip")
        f.create_dataset("timestamps", data=np.array(ts_log))
        for name, frames in cam_buffers.items():
            if not frames:
                continue
            dset = f.create_dataset(f"cams/{name}/rgb", data=np.stack(frames), compression="gzip")
            # Intrinsics: set these to match your camera's intrinsics vector
            if name in ["cam_head", "cam_chest", "cam_left_hand", "cam_right_hand"]:
                fx, fy, cx, cy = 911.77, 911.5, 624.07, 364.05
            else:
                fx, fy, cx, cy = 458.7488, 458.8663, 323.3297, 240.6295
            dset.attrs["fx"] = fx
            dset.attrs["fy"] = fy
            dset.attrs["cx"] = cx
            dset.attrs["cy"] = cy
            dset.attrs["focal_length"] = 1.93
    print(f"[INFO] Saved HDF5 to {h5_path}")
    # Optional: single-camera MP4
    head_frames = cam_buffers["cam_head"]
    if head_frames:
        fps = int(round(1.0 / dt))
        video_path = os.path.join(log_dir, "cam_head.mp4")
        imageio.mimsave(video_path, head_frames, fps=fps)
        print(f"[INFO] Saved video to {video_path}")
    env.close()
if __name__ == "__main__":
    main()
    simulation_app.close()
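
# Standalone snippet: dump the dataset tree of a saved rollout (run this in a separate Python session).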
# import h5py
# h5_path = r"C:\Users\PC\workpalce\mindbot\logs\rsl_rl\mindbot_grasp\2026-01-15_11-51-00\rollout_multi_cam.h5"
# with h5py.File(h5_path, "r") as f:
#     def walk(name, obj):
#         if isinstance(obj, h5py.Dataset):
#             print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")
#         else:
#             print(f"{name}/")
#     f.visititems(walk)
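# A minimal read-back sketch under the same assumptions (dataset names as written by this script):
# with h5py.File(h5_path, "r") as f:
#     rgb = f["cams/cam_head/rgb"][:]            # (T, H, W, 3) frames from the head camera
#     fx = f["cams/cam_head/rgb"].attrs["fx"]    # intrinsics stored as dataset attributes
#     print(rgb.shape, fx)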