本地4090代码提交
This commit is contained in:
270
scripts/rsl_rl/play.py.bak1
Normal file
270
scripts/rsl_rl/play.py.bak1
Normal file
@@ -0,0 +1,270 @@
|
||||
import argparse
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
import h5py
|
||||
import torch
|
||||
|
||||
# 先引入 AppLauncher,并尽早实例化 SimulationApp,避免 pxr 未加载
|
||||
from isaaclab.app import AppLauncher
|
||||
import cli_args # isort: skip
|
||||
|
||||
# CLI
|
||||
parser = argparse.ArgumentParser(description="Play an RL agent with multi-cam recording.")
|
||||
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
|
||||
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
|
||||
parser.add_argument("--disable_fabric", action="store_true", default=False, help="Disable fabric and use USD I/O operations.")
|
||||
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
|
||||
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
|
||||
parser.add_argument("--agent", type=str, default="rsl_rl_cfg_entry_point", help="Name of the RL agent configuration entry point.")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
|
||||
parser.add_argument("--use_pretrained_checkpoint", action="store_true", help="Use the pre-trained checkpoint from Nucleus.")
|
||||
parser.add_argument("--real-time", action="store_true", default=False, help="Run in real-time, if possible.")
|
||||
parser.add_argument("--max_steps", type=int, default=None, help="最大步数,达到后提前退出")
|
||||
cli_args.add_rsl_rl_args(parser)
|
||||
AppLauncher.add_app_launcher_args(parser)
|
||||
args_cli, hydra_args = parser.parse_known_args()
|
||||
if args_cli.video:
|
||||
args_cli.enable_cameras = True
|
||||
# 将 hydra 剩余参数传递
|
||||
sys.argv = [sys.argv[0]] + hydra_args
|
||||
|
||||
# ==== 先实例化 SimulationApp ====
|
||||
app_launcher = AppLauncher(args_cli)
|
||||
simulation_app = app_launcher.app
|
||||
|
||||
# ==== 之后再 import 依赖 isaac/pxr 的模块 ====
|
||||
import gymnasium as gym
|
||||
from rsl_rl.runners import DistillationRunner, OnPolicyRunner
|
||||
from isaaclab.envs import DirectMARLEnv, DirectMARLEnvCfg, DirectRLEnvCfg, ManagerBasedRLEnvCfg, multi_agent_to_single_agent
|
||||
from isaaclab.utils.assets import retrieve_file_path
|
||||
from isaaclab.utils.dict import print_dict
|
||||
from isaaclab_rl.rsl_rl import RslRlBaseRunnerCfg, RslRlVecEnvWrapper, export_policy_as_jit, export_policy_as_onnx
|
||||
from isaaclab_rl.utils.pretrained_checkpoint import get_published_pretrained_checkpoint
|
||||
|
||||
import isaaclab_tasks # noqa: F401
|
||||
from isaaclab_tasks.utils import get_checkpoint_path
|
||||
from isaaclab_tasks.utils.hydra import hydra_task_config
|
||||
import mindbot.tasks # noqa: F401
|
||||
|
||||
CAM_NAMES = ["cam_head", "cam_chest", "cam_left_hand", "cam_right_hand", "cam_top", "cam_side"]
|
||||
|
||||
|
||||
@hydra_task_config(args_cli.task, args_cli.agent)
|
||||
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: RslRlBaseRunnerCfg):
|
||||
task_name = args_cli.task.split(":")[-1]
|
||||
train_task_name = task_name.replace("-Play", "")
|
||||
|
||||
agent_cfg: RslRlBaseRunnerCfg = cli_args.update_rsl_rl_cfg(agent_cfg, args_cli)
|
||||
env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
|
||||
env_cfg.seed = agent_cfg.seed
|
||||
env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
|
||||
|
||||
log_root_path = os.path.abspath(os.path.join("logs", "rsl_rl", agent_cfg.experiment_name))
|
||||
print(f"[INFO] Loading experiment from directory: {log_root_path}")
|
||||
if args_cli.use_pretrained_checkpoint:
|
||||
resume_path = get_published_pretrained_checkpoint("rsl_rl", train_task_name)
|
||||
if not resume_path:
|
||||
print("[INFO] No pre-trained checkpoint for this task.")
|
||||
return
|
||||
elif args_cli.checkpoint:
|
||||
resume_path = retrieve_file_path(args_cli.checkpoint)
|
||||
else:
|
||||
resume_path = get_checkpoint_path(log_root_path, agent_cfg.load_run, agent_cfg.load_checkpoint)
|
||||
|
||||
log_dir = os.path.dirname(resume_path)
|
||||
env_cfg.log_dir = log_dir
|
||||
|
||||
env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
|
||||
if isinstance(env.unwrapped, DirectMARLEnv):
|
||||
env = multi_agent_to_single_agent(env)
|
||||
env = RslRlVecEnvWrapper(env, clip_actions=agent_cfg.clip_actions)
|
||||
|
||||
print(f"[INFO]: Loading model checkpoint from: {resume_path}")
|
||||
if agent_cfg.class_name == "OnPolicyRunner":
|
||||
runner = OnPolicyRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
|
||||
elif agent_cfg.class_name == "DistillationRunner":
|
||||
runner = DistillationRunner(env, agent_cfg.to_dict(), log_dir=None, device=agent_cfg.device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported runner class: {agent_cfg.class_name}")
|
||||
runner.load(resume_path)
|
||||
|
||||
policy = runner.get_inference_policy(device=env.unwrapped.device)
|
||||
try:
|
||||
policy_nn = runner.alg.policy
|
||||
except AttributeError:
|
||||
policy_nn = runner.alg.actor_critic
|
||||
|
||||
# 导出模型
|
||||
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
|
||||
normalizer = getattr(policy_nn, "actor_obs_normalizer", getattr(policy_nn, "student_obs_normalizer", None))
|
||||
export_policy_as_jit(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.pt")
|
||||
export_policy_as_onnx(policy_nn, normalizer=normalizer, path=export_model_dir, filename="policy.onnx")
|
||||
|
||||
# dt = env.unwrapped.step_dt
|
||||
|
||||
# # 录制缓冲
|
||||
# cam_buffers = {n: [] for n in CAM_NAMES}
|
||||
# joint_log, joint_vel_log, action_log, ts_log = [], [], [], []
|
||||
# t0 = time.time()
|
||||
|
||||
# obs = env.get_observations()
|
||||
# timestep = 0
|
||||
# step_count = 0
|
||||
# try:
|
||||
# while simulation_app.is_running():
|
||||
# start_time = time.time()
|
||||
# with torch.inference_mode():
|
||||
# actions = policy(obs)
|
||||
# obs, _, dones, _ = env.step(actions)
|
||||
# policy_nn.reset(dones)
|
||||
|
||||
# # 相机帧(取 env0,如需全部 env 去掉 [0])
|
||||
# for name in CAM_NAMES:
|
||||
# if name not in env.unwrapped.scene.sensors:
|
||||
# continue
|
||||
# cam = env.unwrapped.scene.sensors[name]
|
||||
# rgba = cam.data.output.get("rgba", cam.data.output.get("rgb"))
|
||||
# if rgba is None:
|
||||
# continue
|
||||
# frame = rgba[0].cpu().numpy()
|
||||
# if frame.shape[-1] == 4:
|
||||
# frame = frame[..., :3]
|
||||
# cam_buffers[name].append(frame)
|
||||
|
||||
# # 关节 / 速度 / 动作
|
||||
# robot = env.unwrapped.scene["Mindbot"]
|
||||
# joint_log.append(robot.data.joint_pos.cpu().numpy())
|
||||
# joint_vel_log.append(robot.data.joint_vel.cpu().numpy())
|
||||
# action_log.append(actions.cpu().numpy())
|
||||
# ts_log.append(time.time() - t0)
|
||||
|
||||
# step_count += 1
|
||||
# if args_cli.max_steps and step_count >= args_cli.max_steps:
|
||||
# break
|
||||
|
||||
# if args_cli.video:
|
||||
# timestep += 1
|
||||
# if timestep == args_cli.video_length:
|
||||
# break
|
||||
|
||||
# sleep_time = dt - (time.time() - start_time)
|
||||
# if args_cli.real_time and sleep_time > 0:
|
||||
# time.sleep(sleep_time)
|
||||
# finally:
|
||||
# # 保存 HDF5
|
||||
# h5_path = os.path.join(log_dir, "rollout_multi_cam.h5")
|
||||
# with h5py.File(h5_path, "w") as f:
|
||||
# f.create_dataset("joint_pos", data=np.stack(joint_log), compression="gzip")
|
||||
# f.create_dataset("joint_vel", data=np.stack(joint_vel_log), compression="gzip")
|
||||
# f.create_dataset("actions", data=np.stack(action_log), compression="gzip")
|
||||
# f.create_dataset("timestamps", data=np.array(ts_log))
|
||||
# for name, frames in cam_buffers.items():
|
||||
# if not frames:
|
||||
# continue
|
||||
# dset = f.create_dataset(f"cams/{name}/rgb", data=np.stack(frames), compression="gzip")
|
||||
# if name in ["cam_head", "cam_chest", "cam_left_hand", "cam_right_hand"]:
|
||||
# fx, fy, cx, cy = 911.77, 911.5, 624.07, 364.05
|
||||
# else:
|
||||
# fx, fy, cx, cy = 458.7488, 458.8663, 323.3297, 240.6295
|
||||
# dset.attrs["fx"] = fx
|
||||
# dset.attrs["fy"] = fy
|
||||
# dset.attrs["cx"] = cx
|
||||
# dset.attrs["cy"] = cy
|
||||
# dset.attrs["focal_length"] = 1.93
|
||||
# print(f"[INFO] Saved HDF5 to {h5_path}")
|
||||
|
||||
# # 可选:单路 MP4
|
||||
# head_frames = cam_buffers["cam_head"]
|
||||
# if head_frames:
|
||||
# fps = int(round(1.0 / dt))
|
||||
# video_path = os.path.join(log_dir, "cam_head.mp4")
|
||||
# imageio.mimsave(video_path, head_frames, fps=fps)
|
||||
# print(f"[INFO] Saved video to {video_path}")
|
||||
|
||||
# env.close()
|
||||
dt = env.unwrapped.step_dt
|
||||
|
||||
# 录制缓冲
|
||||
cam_buffers = {n: [] for n in CAM_NAMES}
|
||||
joint_log, joint_vel_log, action_log, ts_log = [], [], [], []
|
||||
t0 = time.time()
|
||||
|
||||
obs = env.get_observations()
|
||||
timestep = 0
|
||||
while simulation_app.is_running():
|
||||
start_time = time.time()
|
||||
with torch.inference_mode():
|
||||
actions = policy(obs)
|
||||
obs, _, dones, _ = env.step(actions)
|
||||
policy_nn.reset(dones)
|
||||
|
||||
# 相机帧(取 env0,如需全部 env 去掉 [0])
|
||||
for name in CAM_NAMES:
|
||||
if name not in env.unwrapped.scene.sensors:
|
||||
continue
|
||||
cam = env.unwrapped.scene.sensors[name]
|
||||
rgba = cam.data.output.get("rgba", cam.data.output.get("rgb"))
|
||||
if rgba is None:
|
||||
continue
|
||||
frame = rgba[0].cpu().numpy()
|
||||
if frame.shape[-1] == 4:
|
||||
frame = frame[..., :3]
|
||||
cam_buffers[name].append(frame)
|
||||
|
||||
# 关节 / 速度 / 动作
|
||||
robot = env.unwrapped.scene["Mindbot"]
|
||||
joint_log.append(robot.data.joint_pos.cpu().numpy())
|
||||
joint_vel_log.append(robot.data.joint_vel.cpu().numpy())
|
||||
action_log.append(actions.cpu().numpy())
|
||||
ts_log.append(time.time() - t0)
|
||||
|
||||
if args_cli.video:
|
||||
timestep += 1
|
||||
if timestep == args_cli.video_length:
|
||||
break
|
||||
|
||||
sleep_time = dt - (time.time() - start_time)
|
||||
if args_cli.real_time and sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
# 保存 HDF5
|
||||
h5_path = os.path.join(log_dir, "rollout_multi_cam.h5")
|
||||
with h5py.File(h5_path, "w") as f:
|
||||
f.create_dataset("joint_pos", data=np.stack(joint_log), compression="gzip")
|
||||
f.create_dataset("joint_vel", data=np.stack(joint_vel_log), compression="gzip")
|
||||
f.create_dataset("actions", data=np.stack(action_log), compression="gzip")
|
||||
f.create_dataset("timestamps", data=np.array(ts_log))
|
||||
for name, frames in cam_buffers.items():
|
||||
if not frames:
|
||||
continue
|
||||
dset = f.create_dataset(f"cams/{name}/rgb", data=np.stack(frames), compression="gzip")
|
||||
# 内参:按你相机 vector 设置
|
||||
if name in ["cam_head", "cam_chest", "cam_left_hand", "cam_right_hand"]:
|
||||
fx, fy, cx, cy = 911.77, 911.5, 624.07, 364.05
|
||||
else:
|
||||
fx, fy, cx, cy = 458.7488, 458.8663, 323.3297, 240.6295
|
||||
dset.attrs["fx"] = fx
|
||||
dset.attrs["fy"] = fy
|
||||
dset.attrs["cx"] = cx
|
||||
dset.attrs["cy"] = cy
|
||||
dset.attrs["focal_length"] = 1.93
|
||||
print(f"[INFO] Saved HDF5 to {h5_path}")
|
||||
|
||||
# 可选:单路 MP4
|
||||
head_frames = cam_buffers["cam_head"]
|
||||
if head_frames:
|
||||
fps = int(round(1.0 / dt))
|
||||
video_path = os.path.join(log_dir, "cam_head.mp4")
|
||||
imageio.mimsave(video_path, head_frames, fps=fps)
|
||||
print(f"[INFO] Saved video to {video_path}")
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
simulation_app.close()
|
||||
Reference in New Issue
Block a user