forked from tangger/lerobot
Compare commits
18 Commits
main
...
recovered-
| Author | SHA1 | Date | |
|---|---|---|---|
| ef45ea9649 | |||
| bc351a0134 | |||
| 68986f6fc0 | |||
| 2f124e34de | |||
| c28e774234 | |||
| 80b1a97e4c | |||
| f4fec8f51c | |||
| f4f82c916f | |||
| ecbe154709 | |||
| d00c154db9 | |||
| 55f284b306 | |||
| cf8df17d3a | |||
| e079566597 | |||
| 83d6419d70 | |||
| a0ec9e1cb1 | |||
| 3eede4447d | |||
| 9c6a7d9701 | |||
| 7b201773f3 |
@@ -58,7 +58,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
|
|||||||
log_dt("dt", dt_s)
|
log_dt("dt", dt_s)
|
||||||
|
|
||||||
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
|
||||||
if not robot.robot_type.startswith("stretch"):
|
if not robot.robot_type.startswith(("stretch", "realman")):
|
||||||
for name in robot.leader_arms:
|
for name in robot.leader_arms:
|
||||||
key = f"read_leader_{name}_pos_dt_s"
|
key = f"read_leader_{name}_pos_dt_s"
|
||||||
if key in robot.logs:
|
if key in robot.logs:
|
||||||
|
|||||||
@@ -39,3 +39,12 @@ class FeetechMotorsBusConfig(MotorsBusConfig):
|
|||||||
port: str
|
port: str
|
||||||
motors: dict[str, tuple[int, str]]
|
motors: dict[str, tuple[int, str]]
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@MotorsBusConfig.register_subclass("realman")
|
||||||
|
@dataclass
|
||||||
|
class RealmanMotorsBusConfig(MotorsBusConfig):
|
||||||
|
ip: str
|
||||||
|
port: int
|
||||||
|
motors: dict[str, tuple[int, str]]
|
||||||
|
init_joint: dict[str, list]
|
||||||
150
lerobot/common/robot_devices/motors/realman.py
Normal file
150
lerobot/common/robot_devices/motors/realman.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
import time
|
||||||
|
from typing import Dict
|
||||||
|
from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig
|
||||||
|
from Robotic_Arm.rm_robot_interface import *
|
||||||
|
|
||||||
|
|
||||||
|
class RealmanMotorsBus:
|
||||||
|
"""
|
||||||
|
对Realman SDK的二次封装
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
config: RealmanMotorsBusConfig):
|
||||||
|
self.rmarm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||||
|
self.handle = self.rmarm.rm_create_robot_arm(config.ip, config.port)
|
||||||
|
self.motors = config.motors
|
||||||
|
self.init_joint_position = config.init_joint['joint'] # [6 joints + 1 gripper]
|
||||||
|
self.safe_disable_position = config.init_joint['joint']
|
||||||
|
self.rmarm.rm_movej(self.init_joint_position[:-1], 5, 0, 0, 1)
|
||||||
|
time.sleep(3)
|
||||||
|
ret = self.rmarm.rm_get_current_arm_state()
|
||||||
|
self.init_pose = ret[1]['pose']
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motor_names(self) -> list[str]:
|
||||||
|
return list(self.motors.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motor_models(self) -> list[str]:
|
||||||
|
return [model for _, model in self.motors.values()]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motor_indices(self) -> list[int]:
|
||||||
|
return [idx for idx, _ in self.motors.values()]
|
||||||
|
|
||||||
|
|
||||||
|
def connect(self, enable=True) -> bool:
|
||||||
|
'''
|
||||||
|
使能机械臂并检测使能状态,尝试5s,如果使能超时则退出程序
|
||||||
|
'''
|
||||||
|
enable_flag = False
|
||||||
|
loop_flag = False
|
||||||
|
# 设置超时时间(秒)
|
||||||
|
timeout = 5
|
||||||
|
# 记录进入循环前的时间
|
||||||
|
start_time = time.time()
|
||||||
|
elapsed_time_flag = False
|
||||||
|
|
||||||
|
while not loop_flag:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
print("--------------------")
|
||||||
|
|
||||||
|
if enable:
|
||||||
|
# 获取机械臂状态
|
||||||
|
ret = self.rmarm.rm_get_current_arm_state()
|
||||||
|
if ret[0] == 0: # 成功获取状态
|
||||||
|
enable_flag = True
|
||||||
|
else:
|
||||||
|
enable_flag = False
|
||||||
|
# 断开所有连接,销毁线程
|
||||||
|
RoboticArm.rm_destory()
|
||||||
|
print("使能状态:", enable_flag)
|
||||||
|
print("--------------------")
|
||||||
|
if(enable_flag == enable):
|
||||||
|
loop_flag = True
|
||||||
|
enable_flag = True
|
||||||
|
else:
|
||||||
|
loop_flag = False
|
||||||
|
enable_flag = False
|
||||||
|
# 检查是否超过超时时间
|
||||||
|
if elapsed_time > timeout:
|
||||||
|
print("超时....")
|
||||||
|
elapsed_time_flag = True
|
||||||
|
enable_flag = True
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
resp = enable_flag
|
||||||
|
print(f"Returning response: {resp}")
|
||||||
|
return resp
|
||||||
|
|
||||||
|
def motor_names(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def set_calibration(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def revert_calibration(self):
|
||||||
|
return
|
||||||
|
|
||||||
|
def apply_calibration(self):
|
||||||
|
"""
|
||||||
|
移动到初始位置
|
||||||
|
"""
|
||||||
|
self.write(target_joint=self.init_joint_position)
|
||||||
|
|
||||||
|
def write(self, target_joint:list):
|
||||||
|
# self.rmarm.rm_movej(target_joint[:-1], 50, 0, 0, 1)
|
||||||
|
self.rmarm.rm_movej_follow(target_joint[:-1])
|
||||||
|
self.rmarm.rm_set_gripper_position(target_joint[-1], block=False, timeout=2)
|
||||||
|
|
||||||
|
def write_endpose(self, target_endpose: list, gripper: int):
|
||||||
|
self.rmarm.rm_movej_p(target_endpose, 50, 0, 0, 1)
|
||||||
|
self.rmarm.rm_set_gripper_position(gripper, block=False, timeout=2)
|
||||||
|
|
||||||
|
def write_joint_slow(self, target_joint: list):
|
||||||
|
self.rmarm.rm_movej(target_joint, 5, 0, 0, 0)
|
||||||
|
|
||||||
|
def write_joint_canfd(self, target_joint: list):
|
||||||
|
self.rmarm.rm_movej_canfd(target_joint, False)
|
||||||
|
|
||||||
|
def write_endpose_canfd(self, target_pose: list):
|
||||||
|
self.rmarm.rm_movep_canfd(target_pose, False)
|
||||||
|
|
||||||
|
def write_gripper(self, gripper: int):
|
||||||
|
self.rmarm.rm_set_gripper_position(gripper, False, 2)
|
||||||
|
|
||||||
|
def read(self) -> Dict:
|
||||||
|
"""
|
||||||
|
- 机械臂关节消息,单位1度;[-1, 1]
|
||||||
|
- 机械臂夹爪消息,[-1, 1]
|
||||||
|
"""
|
||||||
|
joint_msg = self.rmarm.rm_get_current_arm_state()[1]
|
||||||
|
joint_state = joint_msg['joint']
|
||||||
|
|
||||||
|
gripper_msg = self.rmarm.rm_get_gripper_state()[1]
|
||||||
|
gripper_state = gripper_msg['actpos']
|
||||||
|
|
||||||
|
return {
|
||||||
|
"joint_1": joint_state[0]/180,
|
||||||
|
"joint_2": joint_state[1]/180,
|
||||||
|
"joint_3": joint_state[2]/180,
|
||||||
|
"joint_4": joint_state[3]/180,
|
||||||
|
"joint_5": joint_state[4]/180,
|
||||||
|
"joint_6": joint_state[5]/180,
|
||||||
|
"gripper": (gripper_state-500)/500
|
||||||
|
}
|
||||||
|
|
||||||
|
def read_current_arm_joint_state(self):
|
||||||
|
return self.rmarm.rm_get_current_arm_state()[1]['joint']
|
||||||
|
|
||||||
|
def read_current_arm_endpose_state(self):
|
||||||
|
return self.rmarm.rm_get_current_arm_state()[1]['pose']
|
||||||
|
|
||||||
|
def safe_disconnect(self):
|
||||||
|
"""
|
||||||
|
Move to safe disconnect position
|
||||||
|
"""
|
||||||
|
self.write(target_joint=self.safe_disable_position)
|
||||||
|
# 断开所有连接,销毁线程
|
||||||
|
RoboticArm.rm_destory()
|
||||||
@@ -44,6 +44,11 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
|
|||||||
|
|
||||||
motors_buses[key] = FeetechMotorsBus(cfg)
|
motors_buses[key] = FeetechMotorsBus(cfg)
|
||||||
|
|
||||||
|
elif cfg.type == "realman":
|
||||||
|
from lerobot.common.robot_devices.motors.realman import RealmanMotorsBus
|
||||||
|
|
||||||
|
motors_buses[key] = RealmanMotorsBus(cfg)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
|
||||||
|
|
||||||
@@ -65,3 +70,7 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
raise ValueError(f"The motor type '{motor_type}' is not valid.")
|
||||||
|
|
||||||
|
|
||||||
|
def get_motor_names(arm: dict[str, MotorsBus]) -> list:
|
||||||
|
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]
|
||||||
@@ -27,6 +27,7 @@ from lerobot.common.robot_devices.motors.configs import (
|
|||||||
DynamixelMotorsBusConfig,
|
DynamixelMotorsBusConfig,
|
||||||
FeetechMotorsBusConfig,
|
FeetechMotorsBusConfig,
|
||||||
MotorsBusConfig,
|
MotorsBusConfig,
|
||||||
|
RealmanMotorsBusConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -674,3 +675,91 @@ class LeKiwiRobotConfig(RobotConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
mock: bool = False
|
mock: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@RobotConfig.register_subclass("realman")
|
||||||
|
@dataclass
|
||||||
|
class RealmanRobotConfig(RobotConfig):
|
||||||
|
inference_time: bool = False
|
||||||
|
max_gripper: int = 990
|
||||||
|
min_gripper: int = 10
|
||||||
|
servo_config_file: str = "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
left_follower_arm: dict[str, MotorsBusConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"main": RealmanMotorsBusConfig(
|
||||||
|
ip = "192.168.3.18",
|
||||||
|
port = 8080,
|
||||||
|
motors={
|
||||||
|
# name: (index, model)
|
||||||
|
"joint_1": [1, "realman"],
|
||||||
|
"joint_2": [2, "realman"],
|
||||||
|
"joint_3": [3, "realman"],
|
||||||
|
"joint_4": [4, "realman"],
|
||||||
|
"joint_5": [5, "realman"],
|
||||||
|
"joint_6": [6, "realman"],
|
||||||
|
"gripper": [7, "realman"],
|
||||||
|
},
|
||||||
|
init_joint = {'joint': [-90, 90, 90, 90, 90, -90, 10]}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cameras: dict[str, CameraConfig] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
# "one": OpenCVCameraConfig(
|
||||||
|
# camera_index=4,
|
||||||
|
# fps=30,
|
||||||
|
# width=640,
|
||||||
|
# height=480,
|
||||||
|
# ),
|
||||||
|
"left": IntelRealSenseCameraConfig(
|
||||||
|
serial_number="153122077516",
|
||||||
|
fps=30,
|
||||||
|
width=640,
|
||||||
|
height=480,
|
||||||
|
use_depth=False
|
||||||
|
),
|
||||||
|
# "right": IntelRealSenseCameraConfig(
|
||||||
|
# serial_number="405622075165",
|
||||||
|
# fps=30,
|
||||||
|
# width=640,
|
||||||
|
# height=480,
|
||||||
|
# use_depth=False
|
||||||
|
# ),
|
||||||
|
"front": IntelRealSenseCameraConfig(
|
||||||
|
serial_number="145422072751",
|
||||||
|
fps=30,
|
||||||
|
width=640,
|
||||||
|
height=480,
|
||||||
|
use_depth=False
|
||||||
|
),
|
||||||
|
"high": IntelRealSenseCameraConfig(
|
||||||
|
serial_number="145422072193",
|
||||||
|
fps=30,
|
||||||
|
width=640,
|
||||||
|
height=480,
|
||||||
|
use_depth=False
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# right_follower_arm: dict[str, MotorsBusConfig] = field(
|
||||||
|
# default_factory=lambda: {
|
||||||
|
# "main": RealmanMotorsBusConfig(
|
||||||
|
# ip = "192.168.3.19",
|
||||||
|
# port = 8080,
|
||||||
|
# motors={
|
||||||
|
# # name: (index, model)
|
||||||
|
# "joint_1": [1, "realman"],
|
||||||
|
# "joint_2": [2, "realman"],
|
||||||
|
# "joint_3": [3, "realman"],
|
||||||
|
# "joint_4": [4, "realman"],
|
||||||
|
# "joint_5": [5, "realman"],
|
||||||
|
# "joint_6": [6, "realman"],
|
||||||
|
# "gripper": (7, "realman"),
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
# }
|
||||||
|
# )
|
||||||
|
|||||||
292
lerobot/common/robot_devices/robots/realman.py
Normal file
292
lerobot/common/robot_devices/robots/realman.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""
|
||||||
|
Teleoperation Realman with a PS5 controller and
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from dataclasses import dataclass, field, replace
|
||||||
|
from collections import deque
|
||||||
|
from lerobot.common.robot_devices.teleop.gamepad import HybridController
|
||||||
|
from lerobot.common.robot_devices.motors.utils import get_motor_names, make_motors_buses_from_configs
|
||||||
|
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
|
||||||
|
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
|
||||||
|
from lerobot.common.robot_devices.robots.configs import RealmanRobotConfig
|
||||||
|
|
||||||
|
|
||||||
|
class RealmanRobot:
|
||||||
|
def __init__(self, config: RealmanRobotConfig | None = None, **kwargs):
|
||||||
|
if config is None:
|
||||||
|
config = RealmanRobotConfig()
|
||||||
|
# Overwrite config arguments using kwargs
|
||||||
|
self.config = replace(config, **kwargs)
|
||||||
|
self.robot_type = self.config.type
|
||||||
|
self.inference_time = self.config.inference_time # if it is inference time
|
||||||
|
|
||||||
|
# build cameras
|
||||||
|
self.cameras = make_cameras_from_configs(self.config.cameras)
|
||||||
|
|
||||||
|
# build realman motors
|
||||||
|
self.piper_motors = make_motors_buses_from_configs(self.config.left_follower_arm)
|
||||||
|
self.arm = self.piper_motors['main']
|
||||||
|
|
||||||
|
# build init teleop info
|
||||||
|
self.init_info = {
|
||||||
|
'init_joint': self.arm.init_joint_position,
|
||||||
|
'init_pose': self.arm.init_pose,
|
||||||
|
'max_gripper': config.max_gripper,
|
||||||
|
'min_gripper': config.min_gripper,
|
||||||
|
'servo_config_file': config.servo_config_file
|
||||||
|
}
|
||||||
|
|
||||||
|
# build state-action cache
|
||||||
|
self.joint_queue = deque(maxlen=2)
|
||||||
|
self.last_endpose = self.arm.init_pose
|
||||||
|
|
||||||
|
# build gamepad teleop
|
||||||
|
if not self.inference_time:
|
||||||
|
self.teleop = HybridController(self.init_info)
|
||||||
|
else:
|
||||||
|
self.teleop = None
|
||||||
|
|
||||||
|
self.logs = {}
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def camera_features(self) -> dict:
|
||||||
|
cam_ft = {}
|
||||||
|
for cam_key, cam in self.cameras.items():
|
||||||
|
key = f"observation.images.{cam_key}"
|
||||||
|
cam_ft[key] = {
|
||||||
|
"shape": (cam.height, cam.width, cam.channels),
|
||||||
|
"names": ["height", "width", "channels"],
|
||||||
|
"info": None,
|
||||||
|
}
|
||||||
|
return cam_ft
|
||||||
|
|
||||||
|
|
||||||
|
@property
|
||||||
|
def motor_features(self) -> dict:
|
||||||
|
action_names = get_motor_names(self.piper_motors)
|
||||||
|
state_names = get_motor_names(self.piper_motors)
|
||||||
|
return {
|
||||||
|
"action": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(action_names),),
|
||||||
|
"names": action_names,
|
||||||
|
},
|
||||||
|
"observation.state": {
|
||||||
|
"dtype": "float32",
|
||||||
|
"shape": (len(state_names),),
|
||||||
|
"names": state_names,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_camera(self):
|
||||||
|
return len(self.cameras) > 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_cameras(self):
|
||||||
|
return len(self.cameras)
|
||||||
|
|
||||||
|
|
||||||
|
def connect(self) -> None:
|
||||||
|
"""Connect RealmanArm and cameras"""
|
||||||
|
if self.is_connected:
|
||||||
|
raise RobotDeviceAlreadyConnectedError(
|
||||||
|
"RealmanArm is already connected. Do not run `robot.connect()` twice."
|
||||||
|
)
|
||||||
|
|
||||||
|
# connect RealmanArm
|
||||||
|
self.arm.connect(enable=True)
|
||||||
|
print("RealmanArm conneted")
|
||||||
|
|
||||||
|
# connect cameras
|
||||||
|
for name in self.cameras:
|
||||||
|
self.cameras[name].connect()
|
||||||
|
self.is_connected = self.is_connected and self.cameras[name].is_connected
|
||||||
|
print(f"camera {name} conneted")
|
||||||
|
|
||||||
|
print("All connected")
|
||||||
|
self.is_connected = True
|
||||||
|
|
||||||
|
self.run_calibration()
|
||||||
|
|
||||||
|
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
"""move to home position, disenable piper and cameras"""
|
||||||
|
# move piper to home position, disable
|
||||||
|
if not self.inference_time:
|
||||||
|
self.teleop.stop()
|
||||||
|
|
||||||
|
# disconnect piper
|
||||||
|
self.arm.safe_disconnect()
|
||||||
|
print("RealmanArm disable after 5 seconds")
|
||||||
|
time.sleep(5)
|
||||||
|
self.arm.connect(enable=False)
|
||||||
|
|
||||||
|
# disconnect cameras
|
||||||
|
if len(self.cameras) > 0:
|
||||||
|
for cam in self.cameras.values():
|
||||||
|
cam.disconnect()
|
||||||
|
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
|
||||||
|
def run_calibration(self):
|
||||||
|
"""move piper to the home position"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise ConnectionError()
|
||||||
|
|
||||||
|
self.arm.apply_calibration()
|
||||||
|
if not self.inference_time:
|
||||||
|
self.teleop.reset()
|
||||||
|
|
||||||
|
|
||||||
|
def teleop_step(
|
||||||
|
self, record_data=False
|
||||||
|
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||||
|
if not self.is_connected:
|
||||||
|
raise ConnectionError()
|
||||||
|
|
||||||
|
if self.teleop is None and self.inference_time:
|
||||||
|
self.teleop = HybridController(self.init_info)
|
||||||
|
|
||||||
|
# read target pose state as
|
||||||
|
before_read_t = time.perf_counter()
|
||||||
|
state = self.arm.read() # read current joint position from robot
|
||||||
|
action = self.teleop.get_action() # target joint position and pose end pos from gamepad
|
||||||
|
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
|
|
||||||
|
if action['control_mode'] == 'joint':
|
||||||
|
# 关节控制模式(主模式)
|
||||||
|
current_pose = self.arm.read_current_arm_endpose_state()
|
||||||
|
self.teleop.update_endpose_state(current_pose)
|
||||||
|
|
||||||
|
target_joints = action['joint_angles'][:-1]
|
||||||
|
self.arm.write_gripper(action['gripper'])
|
||||||
|
print(action['gripper'])
|
||||||
|
if action['master_controller_status']['infrared'] == 1:
|
||||||
|
if action['master_controller_status']['button'] == 1:
|
||||||
|
self.arm.write_joint_canfd(target_joints)
|
||||||
|
else:
|
||||||
|
self.arm.write_joint_slow(target_joints)
|
||||||
|
|
||||||
|
# do action
|
||||||
|
before_write_t = time.perf_counter()
|
||||||
|
self.joint_queue.append(list(self.arm.read().values()))
|
||||||
|
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||||
|
|
||||||
|
else:
|
||||||
|
target_pose = list(action['end_pose'])
|
||||||
|
# do action
|
||||||
|
before_write_t = time.perf_counter()
|
||||||
|
if self.last_endpose != target_pose:
|
||||||
|
self.arm.write_endpose_canfd(target_pose)
|
||||||
|
self.last_endpose = target_pose
|
||||||
|
self.arm.write_gripper(action['gripper'])
|
||||||
|
|
||||||
|
target_joints = self.arm.read_current_arm_joint_state()
|
||||||
|
self.joint_queue.append(list(self.arm.read().values()))
|
||||||
|
self.teleop.update_joint_state(target_joints)
|
||||||
|
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
|
||||||
|
|
||||||
|
if not record_data:
|
||||||
|
return
|
||||||
|
|
||||||
|
state = torch.as_tensor(list(self.joint_queue[0]), dtype=torch.float32)
|
||||||
|
action = torch.as_tensor(list(self.joint_queue[-1]), dtype=torch.float32)
|
||||||
|
|
||||||
|
# Capture images from cameras
|
||||||
|
images = {}
|
||||||
|
for name in self.cameras:
|
||||||
|
before_camread_t = time.perf_counter()
|
||||||
|
images[name] = self.cameras[name].async_read()
|
||||||
|
images[name] = torch.from_numpy(images[name])
|
||||||
|
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||||
|
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||||
|
|
||||||
|
# Populate output dictionnaries
|
||||||
|
obs_dict, action_dict = {}, {}
|
||||||
|
obs_dict["observation.state"] = state
|
||||||
|
action_dict["action"] = action
|
||||||
|
for name in self.cameras:
|
||||||
|
obs_dict[f"observation.images.{name}"] = images[name]
|
||||||
|
|
||||||
|
return obs_dict, action_dict
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def send_action(self, action: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Write the predicted actions from policy to the motors"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RobotDeviceNotConnectedError(
|
||||||
|
"Piper is not connected. You need to run `robot.connect()`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# send to motors, torch to list
|
||||||
|
target_joints = action.tolist()
|
||||||
|
len_joint = len(target_joints) - 1
|
||||||
|
target_joints = [target_joints[i]*180 for i in range(len_joint)] + [target_joints[-1]]
|
||||||
|
target_joints[-1] = int(target_joints[-1]*500 + 500)
|
||||||
|
self.arm.write(target_joints)
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def capture_observation(self) -> dict:
|
||||||
|
"""capture current images and joint positions"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise RobotDeviceNotConnectedError(
|
||||||
|
"Piper is not connected. You need to run `robot.connect()`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# read current joint positions
|
||||||
|
before_read_t = time.perf_counter()
|
||||||
|
state = self.arm.read() # 6 joints + 1 gripper
|
||||||
|
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
|
|
||||||
|
state = torch.as_tensor(list(state.values()), dtype=torch.float32)
|
||||||
|
|
||||||
|
# read images from cameras
|
||||||
|
images = {}
|
||||||
|
for name in self.cameras:
|
||||||
|
before_camread_t = time.perf_counter()
|
||||||
|
images[name] = self.cameras[name].async_read()
|
||||||
|
images[name] = torch.from_numpy(images[name])
|
||||||
|
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
|
||||||
|
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||||
|
|
||||||
|
# Populate output dictionnaries and format to pytorch
|
||||||
|
obs_dict = {}
|
||||||
|
obs_dict["observation.state"] = state
|
||||||
|
for name in self.cameras:
|
||||||
|
obs_dict[f"observation.images.{name}"] = images[name]
|
||||||
|
return obs_dict
|
||||||
|
|
||||||
|
def teleop_safety_stop(self):
|
||||||
|
""" move to home position after record one episode """
|
||||||
|
self.run_calibration()
|
||||||
|
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
if self.is_connected:
|
||||||
|
self.disconnect()
|
||||||
|
if not self.inference_time:
|
||||||
|
self.teleop.stop()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
robot = RealmanRobot()
|
||||||
|
robot.connect()
|
||||||
|
# robot.run_calibration()
|
||||||
|
while True:
|
||||||
|
robot.teleop_step(record_data=True)
|
||||||
|
|
||||||
|
robot.capture_observation()
|
||||||
|
dummy_action = torch.Tensor([-0.40586111280653214, 0.5522833506266276, 0.4998166826036241, -0.3539944542778863, -0.524433347913954, 0.9064999898274739, 0.482])
|
||||||
|
robot.send_action(dummy_action)
|
||||||
|
robot.disconnect()
|
||||||
|
print('ok')
|
||||||
@@ -25,6 +25,7 @@ from lerobot.common.robot_devices.robots.configs import (
|
|||||||
So100RobotConfig,
|
So100RobotConfig,
|
||||||
So101RobotConfig,
|
So101RobotConfig,
|
||||||
StretchRobotConfig,
|
StretchRobotConfig,
|
||||||
|
RealmanRobotConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -65,6 +66,9 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
|
|||||||
return StretchRobotConfig(**kwargs)
|
return StretchRobotConfig(**kwargs)
|
||||||
elif robot_type == "lekiwi":
|
elif robot_type == "lekiwi":
|
||||||
return LeKiwiRobotConfig(**kwargs)
|
return LeKiwiRobotConfig(**kwargs)
|
||||||
|
elif robot_type == 'realman':
|
||||||
|
return RealmanRobotConfig(**kwargs)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
raise ValueError(f"Robot type '{robot_type}' is not available.")
|
||||||
|
|
||||||
@@ -78,6 +82,12 @@ def make_robot_from_config(config: RobotConfig):
|
|||||||
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
|
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
|
||||||
|
|
||||||
return MobileManipulator(config)
|
return MobileManipulator(config)
|
||||||
|
|
||||||
|
elif isinstance(config, RealmanRobotConfig):
|
||||||
|
from lerobot.common.robot_devices.robots.realman import RealmanRobot
|
||||||
|
|
||||||
|
return RealmanRobot(config)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from lerobot.common.robot_devices.robots.stretch import StretchRobot
|
from lerobot.common.robot_devices.robots.stretch import StretchRobot
|
||||||
|
|
||||||
|
|||||||
466
lerobot/common/robot_devices/teleop/gamepad.py
Normal file
466
lerobot/common/robot_devices/teleop/gamepad.py
Normal file
@@ -0,0 +1,466 @@
|
|||||||
|
import pygame
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import serial
|
||||||
|
import binascii
|
||||||
|
import logging
|
||||||
|
import yaml
|
||||||
|
from typing import Dict
|
||||||
|
from Robotic_Arm.rm_robot_interface import *
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ServoArm:
|
||||||
|
def __init__(self, config_file="config.yaml"):
|
||||||
|
"""初始化机械臂的串口连接并发送初始数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file (str): 配置文件的路径。
|
||||||
|
"""
|
||||||
|
self.config = self._load_config(config_file)
|
||||||
|
self.port = self.config["port"]
|
||||||
|
self.baudrate = self.config["baudrate"]
|
||||||
|
self.joint_hex_data = self.config["joint_hex_data"]
|
||||||
|
self.control_hex_data = self.config["control_hex_data"]
|
||||||
|
self.arm_axis = self.config.get("arm_axis", 7)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
|
||||||
|
self.bytes_to_send = binascii.unhexlify(self.joint_hex_data.replace(" ", ""))
|
||||||
|
self.serial_conn.write(self.bytes_to_send)
|
||||||
|
time.sleep(1)
|
||||||
|
self.connected = True
|
||||||
|
logging.info(f"串口连接成功: {self.port}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"串口连接失败: {e}")
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
def _load_config(self, config_file):
|
||||||
|
"""加载配置文件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_file (str): 配置文件的路径。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 配置文件内容。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(config_file, "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
return config
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"配置文件加载失败: {e}")
|
||||||
|
# 返回默认配置
|
||||||
|
return {
|
||||||
|
"port": "/dev/ttyUSB0",
|
||||||
|
"baudrate": 460800,
|
||||||
|
"joint_hex_data": "55 AA 02 00 00 67",
|
||||||
|
"control_hex_data": "55 AA 08 00 00 B9",
|
||||||
|
"arm_axis": 6
|
||||||
|
}
|
||||||
|
|
||||||
|
def _bytes_to_signed_int(self, byte_data):
|
||||||
|
"""将字节数据转换为有符号整数。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
byte_data (bytes): 字节数据。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 有符号整数。
|
||||||
|
"""
|
||||||
|
return int.from_bytes(byte_data, byteorder="little", signed=True)
|
||||||
|
|
||||||
|
def _parse_joint_data(self, hex_received):
|
||||||
|
"""解析接收到的十六进制数据并提取关节数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hex_received (str): 接收到的十六进制字符串数据。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 解析后的关节数据。
|
||||||
|
"""
|
||||||
|
logging.debug(f"hex_received: {hex_received}")
|
||||||
|
joints = {}
|
||||||
|
for i in range(self.arm_axis):
|
||||||
|
start = 14 + i * 10
|
||||||
|
end = start + 8
|
||||||
|
joint_hex = hex_received[start:end]
|
||||||
|
joint_byte_data = bytearray.fromhex(joint_hex)
|
||||||
|
joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
|
||||||
|
joints[f"joint_{i+1}"] = joint_value
|
||||||
|
grasp_start = 14 + self.arm_axis*10
|
||||||
|
grasp_hex = hex_received[grasp_start:grasp_start+8]
|
||||||
|
grasp_byte_data = bytearray.fromhex(grasp_hex)
|
||||||
|
# 夹爪进行归一化处理
|
||||||
|
grasp_value = self._bytes_to_signed_int(grasp_byte_data)/1000
|
||||||
|
|
||||||
|
joints["grasp"] = grasp_value
|
||||||
|
return joints
|
||||||
|
|
||||||
|
def _parse_controller_data(self, hex_received):
|
||||||
|
status = {
|
||||||
|
'infrared': 0,
|
||||||
|
'button': 0
|
||||||
|
}
|
||||||
|
if len(hex_received) == 18:
|
||||||
|
status['infrared'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[12:14]))
|
||||||
|
status['button'] = self._bytes_to_signed_int(bytearray.fromhex(hex_received[14:16]))
|
||||||
|
# print(infrared)
|
||||||
|
return status
|
||||||
|
|
||||||
|
def get_joint_actions(self):
|
||||||
|
"""从串口读取数据并解析关节动作。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含关节数据的字典。
|
||||||
|
"""
|
||||||
|
if not self.connected:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.serial_conn.write(self.bytes_to_send)
|
||||||
|
time.sleep(0.02)
|
||||||
|
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
|
||||||
|
if len(bytes_received) == 0:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
|
||||||
|
actions = self._parse_joint_data(hex_received)
|
||||||
|
return actions
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"读取串口数据错误: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def get_controller_status(self):
|
||||||
|
bytes_to_send = binascii.unhexlify(self.control_hex_data.replace(" ", ""))
|
||||||
|
self.serial_conn.write(bytes_to_send)
|
||||||
|
time.sleep(0.02)
|
||||||
|
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
|
||||||
|
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
|
||||||
|
# print("control status:", hex_received)
|
||||||
|
status = self._parse_controller_data(hex_received)
|
||||||
|
return status
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭串口连接"""
|
||||||
|
if self.connected and hasattr(self, 'serial_conn'):
|
||||||
|
self.serial_conn.close()
|
||||||
|
self.connected = False
|
||||||
|
logging.info("串口连接已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
class HybridController:
|
||||||
|
def __init__(self, init_info):
|
||||||
|
# 初始化pygame和手柄
|
||||||
|
pygame.init()
|
||||||
|
pygame.joystick.init()
|
||||||
|
|
||||||
|
# 检查是否有连接的手柄
|
||||||
|
if pygame.joystick.get_count() == 0:
|
||||||
|
raise Exception("未检测到手柄")
|
||||||
|
|
||||||
|
# 初始化手柄
|
||||||
|
self.joystick = pygame.joystick.Joystick(0)
|
||||||
|
self.joystick.init()
|
||||||
|
# 摇杆死区
|
||||||
|
self.deadzone = 0.15
|
||||||
|
# 控制模式: True为关节控制(主模式),False为末端控制
|
||||||
|
self.joint_control_mode = True
|
||||||
|
# 精细控制模式
|
||||||
|
self.fine_control_mode = False
|
||||||
|
|
||||||
|
# 初始化末端姿态和关节角度
|
||||||
|
self.init_joint = init_info['init_joint']
|
||||||
|
self.init_pose = init_info.get('init_pose', [0]*6)
|
||||||
|
self.max_gripper = init_info['max_gripper']
|
||||||
|
self.min_gripper = init_info['min_gripper']
|
||||||
|
servo_config_file = init_info['servo_config_file']
|
||||||
|
self.joint = self.init_joint.copy()
|
||||||
|
self.pose = self.init_pose.copy()
|
||||||
|
self.pose_speeds = [0.0] * 6
|
||||||
|
self.joint_speeds = [0.0] * 6
|
||||||
|
self.tozero = False
|
||||||
|
|
||||||
|
# 主臂关节状态
|
||||||
|
self.master_joint_actions = {}
|
||||||
|
self.master_controller_status = {}
|
||||||
|
self.use_master_arm = False
|
||||||
|
|
||||||
|
# 末端位姿限制
|
||||||
|
self.pose_limits = [
|
||||||
|
(-0.800, 0.800), # X (m)
|
||||||
|
(-0.800, 0.800), # Y (m)
|
||||||
|
(-0.800, 0.800), # Z (m)
|
||||||
|
(-3.14, 3.14), # RX (rad)
|
||||||
|
(-3.14, 3.14), # RY (rad)
|
||||||
|
(-3.14, 3.14) # RZ (rad)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 关节角度限制 (度)
|
||||||
|
self.joint_limits = [
|
||||||
|
(-180, 180), # joint 1
|
||||||
|
(-180, 180), # joint 2
|
||||||
|
(-180, 180), # joint 3
|
||||||
|
(-180, 180), # joint 4
|
||||||
|
(-180, 180), # joint 5
|
||||||
|
(-180, 180) # joint 6
|
||||||
|
]
|
||||||
|
|
||||||
|
# 控制参数
|
||||||
|
self.linear_step = 0.002 # 线性移动步长(m)
|
||||||
|
self.angular_step = 0.01 # 角度步长(rad)
|
||||||
|
|
||||||
|
# 夹爪状态和速度
|
||||||
|
self.gripper_speed = 10
|
||||||
|
self.gripper = self.min_gripper
|
||||||
|
|
||||||
|
# 初始化串口通信(主臂关节状态获取)
|
||||||
|
self.servo_arm = None
|
||||||
|
if servo_config_file:
|
||||||
|
try:
|
||||||
|
self.servo_arm = ServoArm(servo_config_file)
|
||||||
|
self.use_master_arm = True
|
||||||
|
logging.info("串口主臂连接成功,启用主从控制模式")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"串口主臂连接失败: {e}")
|
||||||
|
self.use_master_arm = False
|
||||||
|
|
||||||
|
# 启动更新线程
|
||||||
|
self.running = True
|
||||||
|
self.thread = threading.Thread(target=self.update_controller)
|
||||||
|
self.thread.start()
|
||||||
|
|
||||||
|
print("混合控制器已启动")
|
||||||
|
print("主控制模式: 关节控制")
|
||||||
|
if self.use_master_arm:
|
||||||
|
print("主从控制: 启用")
|
||||||
|
print("Back按钮: 切换控制模式(关节/末端)")
|
||||||
|
print("L3按钮: 切换精细控制模式")
|
||||||
|
print("Start按钮: 重置到初始位置")
|
||||||
|
|
||||||
|
def _apply_nonlinear_mapping(self, value):
|
||||||
|
"""应用非线性映射以提高控制精度"""
|
||||||
|
sign = 1 if value >= 0 else -1
|
||||||
|
return sign * (abs(value) ** 2)
|
||||||
|
|
||||||
|
def _normalize_angle(self, angle):
|
||||||
|
"""将角度归一化到[-π, π]范围内"""
|
||||||
|
import math
|
||||||
|
while angle > math.pi:
|
||||||
|
angle -= 2 * math.pi
|
||||||
|
while angle < -math.pi:
|
||||||
|
angle += 2 * math.pi
|
||||||
|
return angle
|
||||||
|
|
||||||
|
def update_controller(self):
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
pygame.event.pump()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"控制器错误: {e}")
|
||||||
|
self.stop()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查控制模式切换 (Back按钮)
|
||||||
|
if self.joystick.get_button(6): # Back按钮
|
||||||
|
self.joint_control_mode = not self.joint_control_mode
|
||||||
|
mode_str = "关节控制" if self.joint_control_mode else "末端位姿控制"
|
||||||
|
print(f"切换到{mode_str}模式")
|
||||||
|
time.sleep(0.3) # 防止多次触发
|
||||||
|
|
||||||
|
# 检查精细控制模式切换 (L3按钮)
|
||||||
|
if self.joystick.get_button(10): # L3按钮
|
||||||
|
self.fine_control_mode = not self.fine_control_mode
|
||||||
|
print(f"切换到{'精细' if self.fine_control_mode else '普通'}控制模式")
|
||||||
|
time.sleep(0.3) # 防止多次触发
|
||||||
|
|
||||||
|
# 检查重置按钮 (Start按钮)
|
||||||
|
if self.joystick.get_button(7): # Start按钮
|
||||||
|
print("重置机械臂到初始位置...")
|
||||||
|
# print("init_joint", self.init_joint.copy())
|
||||||
|
self.tozero = True
|
||||||
|
self.joint = self.init_joint.copy()
|
||||||
|
self.pose = self.init_pose.copy()
|
||||||
|
self.pose_speeds = [0.0] * 6
|
||||||
|
self.joint_speeds = [0.0] * 6
|
||||||
|
self.gripper_speed = 10
|
||||||
|
self.gripper = self.min_gripper
|
||||||
|
print("机械臂已重置到初始位置")
|
||||||
|
time.sleep(0.3) # 防止多次触发
|
||||||
|
|
||||||
|
# 从串口获取主臂关节状态
|
||||||
|
if self.servo_arm and self.servo_arm.connected:
|
||||||
|
try:
|
||||||
|
self.master_joint_actions = self.servo_arm.get_joint_actions()
|
||||||
|
self.master_controller_status = self.servo_arm.get_controller_status()
|
||||||
|
if self.master_joint_actions:
|
||||||
|
logging.debug(f"主臂关节状态: {self.master_joint_actions}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"获取主臂状态错误: {e}")
|
||||||
|
self.master_joint_actions = {}
|
||||||
|
# print(self.master_joint_actions)
|
||||||
|
|
||||||
|
# 根据控制模式更新相应的控制逻辑
|
||||||
|
if self.joint_control_mode:
|
||||||
|
# 关节控制模式下,优先使用主臂数据,Xbox作为辅助
|
||||||
|
self.update_joint_control()
|
||||||
|
else:
|
||||||
|
# 末端控制模式,使用Xbox控制
|
||||||
|
self.update_end_pose()
|
||||||
|
time.sleep(0.02)
|
||||||
|
# print('gripper:', self.gripper)
|
||||||
|
|
||||||
|
def update_joint_control(self):
|
||||||
|
"""更新关节角度控制 - 优先使用主臂数据"""
|
||||||
|
if self.use_master_arm and self.master_joint_actions:
|
||||||
|
# 主从控制模式:直接使用主臂的关节角度
|
||||||
|
try:
|
||||||
|
# 将主臂关节角度映射到从臂
|
||||||
|
for i in range(6): # 假设只有6个关节需要控制
|
||||||
|
joint_key = f"joint_{i+1}"
|
||||||
|
if joint_key in self.master_joint_actions:
|
||||||
|
# 直接使用主臂的关节角度(已经是度数)
|
||||||
|
self.joint[i] = self.master_joint_actions[joint_key]
|
||||||
|
|
||||||
|
# 应用关节限制
|
||||||
|
min_val, max_val = self.joint_limits[i]
|
||||||
|
self.joint[i] = max(min_val, min(max_val, self.joint[i]))
|
||||||
|
|
||||||
|
# print(self.joint)
|
||||||
|
logging.debug(f"主臂关节映射到从臂: {self.joint[:6]}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"主臂数据映射错误: {e}")
|
||||||
|
|
||||||
|
# 如果有主臂夹爪数据,使用主臂夹爪状态
|
||||||
|
if self.use_master_arm and "grasp" in self.master_joint_actions:
|
||||||
|
self.gripper = self.master_joint_actions["grasp"] * 1000
|
||||||
|
self.joint[-1] = self.gripper
|
||||||
|
|
||||||
|
|
||||||
|
def update_end_pose(self):
|
||||||
|
"""更新末端位姿控制"""
|
||||||
|
# 根据控制模式调整步长
|
||||||
|
current_linear_step = self.linear_step * (0.1 if self.fine_control_mode else 1.0)
|
||||||
|
current_angular_step = self.angular_step * (0.1 if self.fine_control_mode else 1.0)
|
||||||
|
|
||||||
|
# 方向键控制XY
|
||||||
|
hat = self.joystick.get_hat(0)
|
||||||
|
hat_up = hat[1] == 1 # Y+
|
||||||
|
hat_down = hat[1] == -1 # Y-
|
||||||
|
hat_left = hat[0] == -1 # X-
|
||||||
|
hat_right = hat[0] == 1 # X+
|
||||||
|
|
||||||
|
# 右摇杆控制Z
|
||||||
|
right_y_raw = -self.joystick.get_axis(4)
|
||||||
|
# 左摇杆控制RZ
|
||||||
|
left_y_raw = -self.joystick.get_axis(1)
|
||||||
|
|
||||||
|
# 应用死区
|
||||||
|
right_y = 0.0 if abs(right_y_raw) < self.deadzone else right_y_raw
|
||||||
|
left_y = 0.0 if abs(left_y_raw) < self.deadzone else left_y_raw
|
||||||
|
|
||||||
|
# 计算各轴速度
|
||||||
|
self.pose_speeds[1] = current_linear_step if hat_up else (-current_linear_step if hat_down else 0.0) # Y
|
||||||
|
self.pose_speeds[0] = -current_linear_step if hat_left else (current_linear_step if hat_right else 0.0) # X
|
||||||
|
|
||||||
|
# 设置Z速度(右摇杆Y轴控制)
|
||||||
|
z_mapping = self._apply_nonlinear_mapping(right_y)
|
||||||
|
self.pose_speeds[2] = z_mapping * current_linear_step # Z
|
||||||
|
|
||||||
|
# L1/R1控制RX旋转
|
||||||
|
LB = self.joystick.get_button(4) # RX-
|
||||||
|
RB = self.joystick.get_button(5) # RX+
|
||||||
|
self.pose_speeds[3] = (-current_angular_step if LB else (current_angular_step if RB else 0.0))
|
||||||
|
|
||||||
|
# △/□控制RY旋转
|
||||||
|
triangle = self.joystick.get_button(2) # RY+
|
||||||
|
square = self.joystick.get_button(3) # RY-
|
||||||
|
self.pose_speeds[4] = (current_angular_step if triangle else (-current_angular_step if square else 0.0))
|
||||||
|
|
||||||
|
# 左摇杆Y轴控制RZ旋转
|
||||||
|
rz_mapping = self._apply_nonlinear_mapping(left_y)
|
||||||
|
self.pose_speeds[5] = rz_mapping * current_angular_step * 2 # RZ
|
||||||
|
|
||||||
|
# 夹爪控制(圈/叉)
|
||||||
|
circle = self.joystick.get_button(1) # 夹爪开
|
||||||
|
cross = self.joystick.get_button(0) # 夹爪关
|
||||||
|
if circle:
|
||||||
|
self.gripper = min(self.max_gripper, self.gripper + self.gripper_speed)
|
||||||
|
elif cross:
|
||||||
|
self.gripper = max(self.min_gripper, self.gripper - self.gripper_speed)
|
||||||
|
|
||||||
|
# 更新末端位姿
|
||||||
|
for i in range(6):
|
||||||
|
self.pose[i] += self.pose_speeds[i]
|
||||||
|
|
||||||
|
# 角度归一化处理
|
||||||
|
for i in range(3, 6):
|
||||||
|
self.pose[i] = self._normalize_angle(self.pose[i])
|
||||||
|
|
||||||
|
def update_joint_state(self, joint):
|
||||||
|
self.joint = joint
|
||||||
|
# self.tozero = False
|
||||||
|
|
||||||
|
def update_endpose_state(self, end_pose):
|
||||||
|
self.pose = end_pose
|
||||||
|
# self.tozero = False
|
||||||
|
|
||||||
|
def update_tozero_state(self, tozero):
|
||||||
|
self.tozero = tozero
|
||||||
|
|
||||||
|
|
||||||
|
def get_action(self) -> Dict:
|
||||||
|
"""获取当前控制命令"""
|
||||||
|
return {
|
||||||
|
'control_mode': 'joint' if self.joint_control_mode else 'end_pose',
|
||||||
|
'use_master_arm': self.use_master_arm,
|
||||||
|
'master_joint_actions': self.master_joint_actions,
|
||||||
|
'master_controller_status': self.master_controller_status,
|
||||||
|
'end_pose': self.pose,
|
||||||
|
'joint_angles': self.joint,
|
||||||
|
'gripper': int(self.gripper),
|
||||||
|
'tozero': self.tozero
|
||||||
|
}
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""停止控制器"""
|
||||||
|
self.running = False
|
||||||
|
if self.thread.is_alive():
|
||||||
|
self.thread.join()
|
||||||
|
if self.servo_arm:
|
||||||
|
self.servo_arm.close()
|
||||||
|
pygame.quit()
|
||||||
|
print("混合控制器已退出")
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
"""重置到初始状态"""
|
||||||
|
self.joint = self.init_joint.copy()
|
||||||
|
self.pose = self.init_pose.copy()
|
||||||
|
self.pose_speeds = [0.0] * 6
|
||||||
|
self.joint_speeds = [0.0] * 6
|
||||||
|
self.gripper_speed = 10
|
||||||
|
self.gripper = self.min_gripper
|
||||||
|
print("已重置到初始状态")
|
||||||
|
|
||||||
|
|
||||||
|
# 使用示例
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 初始化睿尔曼机械臂
|
||||||
|
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||||
|
# 创建机械臂连接
|
||||||
|
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
|
||||||
|
print(f"机械臂连接ID: {handle.id}")
|
||||||
|
init_pose = arm.rm_get_current_arm_state()[1]['pose']
|
||||||
|
|
||||||
|
with open('/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/realman_mix.yaml', "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
config['init_pose'] = init_pose
|
||||||
|
arm_controller = HybridController(config)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
print(arm_controller.get_action())
|
||||||
|
time.sleep(0.1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
arm_controller.stop()
|
||||||
4
lerobot/common/robot_devices/teleop/realman_mix.yaml
Normal file
4
lerobot/common/robot_devices/teleop/realman_mix.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
init_joint: [-90, 90, 90, -90, -90, 90]
|
||||||
|
max_gripper: 990
|
||||||
|
min_gripper: 10
|
||||||
|
servo_config_file: "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"
|
||||||
6
lerobot/common/robot_devices/teleop/servo_arm.yaml
Normal file
6
lerobot/common/robot_devices/teleop/servo_arm.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
port: /dev/ttyUSB0
|
||||||
|
right_port: /dev/ttyUSB1
|
||||||
|
baudrate: 460800
|
||||||
|
joint_hex_data: "55 AA 02 00 00 67"
|
||||||
|
control_hex_data: "55 AA 08 00 00 B9"
|
||||||
|
arm_axis: 6
|
||||||
@@ -175,7 +175,8 @@ def say(text, blocking=False):
|
|||||||
cmd = ["say", text]
|
cmd = ["say", text]
|
||||||
|
|
||||||
elif system == "Linux":
|
elif system == "Linux":
|
||||||
cmd = ["spd-say", text]
|
# cmd = ["spd-say", text]
|
||||||
|
cmd = ["edge-playback", "-t", text]
|
||||||
if blocking:
|
if blocking:
|
||||||
cmd.append("--wait")
|
cmd.append("--wait")
|
||||||
|
|
||||||
|
|||||||
@@ -273,7 +273,6 @@ def record(
|
|||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||||
|
|
||||||
if not robot.is_connected:
|
if not robot.is_connected:
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|
||||||
@@ -290,6 +289,9 @@ def record(
|
|||||||
if has_method(robot, "teleop_safety_stop"):
|
if has_method(robot, "teleop_safety_stop"):
|
||||||
robot.teleop_safety_stop()
|
robot.teleop_safety_stop()
|
||||||
|
|
||||||
|
# import pdb
|
||||||
|
# pdb.set_trace()
|
||||||
|
|
||||||
recorded_episodes = 0
|
recorded_episodes = 0
|
||||||
while True:
|
while True:
|
||||||
if recorded_episodes >= cfg.num_episodes:
|
if recorded_episodes >= cfg.num_episodes:
|
||||||
|
|||||||
156
realman.md
Normal file
156
realman.md
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
# Install
|
||||||
|
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||||
|
```bash
|
||||||
|
conda create -y -n lerobot python=3.10
|
||||||
|
conda activate lerobot
|
||||||
|
```
|
||||||
|
|
||||||
|
Install 🤗 LeRobot:
|
||||||
|
```bash
|
||||||
|
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
pip install edge-tts
|
||||||
|
sudo apt install mpv -y
|
||||||
|
|
||||||
|
# pip uninstall numpy
|
||||||
|
# pip install numpy==1.26.0
|
||||||
|
# pip install pynput
|
||||||
|
```
|
||||||
|
|
||||||
|
/!\ For Linux only, ffmpeg and opencv requires conda install for now. Run this exact sequence of commands:
|
||||||
|
```bash
|
||||||
|
conda install ffmpeg=7.1.1 -c conda-forge
|
||||||
|
# pip uninstall opencv-python
|
||||||
|
# conda install "opencv>=4.10.0"
|
||||||
|
```
|
||||||
|
|
||||||
|
Install Realman SDK:
|
||||||
|
```bash
|
||||||
|
pip install Robotic_Arm==1.0.4.1
|
||||||
|
pip install pygame
|
||||||
|
```
|
||||||
|
|
||||||
|
# piper集成lerobot
|
||||||
|
见lerobot_piper_tutorial/1. 🤗 LeRobot:新增机械臂的一般流程.pdf
|
||||||
|
|
||||||
|
# Teleoperate
|
||||||
|
```bash
|
||||||
|
cd piper_scripts/
|
||||||
|
bash can_activate.sh can0 1000000
|
||||||
|
|
||||||
|
cd ..
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=piper \
|
||||||
|
--robot.inference_time=false \
|
||||||
|
--control.type=teleoperate
|
||||||
|
```
|
||||||
|
|
||||||
|
# Record
|
||||||
|
Set dataset root path
|
||||||
|
```bash
|
||||||
|
HF_USER=$PWD/data
|
||||||
|
echo $HF_USER
|
||||||
|
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=realman \
|
||||||
|
--robot.inference_time=false \
|
||||||
|
--control.type=record \
|
||||||
|
--control.fps=15 \
|
||||||
|
--control.single_task="move" \
|
||||||
|
--control.repo_id=maic/test \
|
||||||
|
--control.num_episodes=2 \
|
||||||
|
--control.warmup_time_s=2 \
|
||||||
|
--control.episode_time_s=10 \
|
||||||
|
--control.reset_time_s=10 \
|
||||||
|
--control.play_sounds=true \
|
||||||
|
--control.push_to_hub=false \
|
||||||
|
--control.display_data=true
|
||||||
|
```
|
||||||
|
|
||||||
|
Press right arrow -> at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording.
|
||||||
|
Press left arrow <- at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it.
|
||||||
|
Press escape ESC at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
|
||||||
|
|
||||||
|
# visualize
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/visualize_dataset.py \
|
||||||
|
--repo-id ${HF_USER}/test \
|
||||||
|
--episode-index 0
|
||||||
|
```
|
||||||
|
|
||||||
|
# Replay
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=piper \
|
||||||
|
--robot.inference_time=false \
|
||||||
|
--control.type=replay \
|
||||||
|
--control.fps=30 \
|
||||||
|
--control.repo_id=${HF_USER}/test \
|
||||||
|
--control.episode=0
|
||||||
|
```
|
||||||
|
|
||||||
|
# Caution
|
||||||
|
|
||||||
|
1. In lerobots/common/datasets/video_utils, the vcodec is set to **libopenh264**, please find your vcodec by **ffmpeg -codecs**
|
||||||
|
|
||||||
|
|
||||||
|
# Train
|
||||||
|
具体的训练流程见lerobot_piper_tutorial/2. 🤗 AutoDL训练.pdf
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
--dataset.repo_id=${HF_USER}/jack \
|
||||||
|
--policy.type=act \
|
||||||
|
--output_dir=outputs/train/act_jack \
|
||||||
|
--job_name=act_jack \
|
||||||
|
--device=cuda \
|
||||||
|
--wandb.enable=true
|
||||||
|
```
|
||||||
|
|
||||||
|
# FT smolvla
|
||||||
|
python lerobot/scripts/train.py \
|
||||||
|
--dataset.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||||
|
--policy.path=lerobot/smolvla_base \
|
||||||
|
--output_dir=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||||
|
--job_name=smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--wandb.enable=false \
|
||||||
|
--steps=200000 \
|
||||||
|
--batch_size=16
|
||||||
|
|
||||||
|
|
||||||
|
# Inference
|
||||||
|
还是使用control_robot.py中的record loop,配置 **--robot.inference_time=true** 可以将手柄移出。
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=realman \
|
||||||
|
--robot.inference_time=true \
|
||||||
|
--control.type=record \
|
||||||
|
--control.fps=30 \
|
||||||
|
--control.single_task="move the bottle into ultrasonic device with realman single" \
|
||||||
|
--control.repo_id=maic/move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||||
|
--control.num_episodes=1 \
|
||||||
|
--control.warmup_time_s=2 \
|
||||||
|
--control.episode_time_s=30 \
|
||||||
|
--control.reset_time_s=10 \
|
||||||
|
--control.push_to_hub=false \
|
||||||
|
--control.policy.path=outputs/train/act_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/100000/pretrained_model
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python lerobot/scripts/control_robot.py \
|
||||||
|
--robot.type=realman \
|
||||||
|
--robot.inference_time=true \
|
||||||
|
--control.type=record \
|
||||||
|
--control.fps=30 \
|
||||||
|
--control.single_task="move the bottle into ultrasonic device with realman single" \
|
||||||
|
--control.repo_id=maic/eval_smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single \
|
||||||
|
--control.num_episodes=1 \
|
||||||
|
--control.warmup_time_s=2 \
|
||||||
|
--control.episode_time_s=60 \
|
||||||
|
--control.reset_time_s=10 \
|
||||||
|
--control.push_to_hub=false \
|
||||||
|
--control.policy.path=outputs/train/smolvla_move_the_bottle_into_ultrasonic_device_with_realman_single/checkpoints/160000/pretrained_model \
|
||||||
|
--control.display_data=true
|
||||||
|
```
|
||||||
31
realman_src/dual_arm_connect_test.py
Normal file
31
realman_src/dual_arm_connect_test.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
from Robotic_Arm.rm_robot_interface import *
|
||||||
|
|
||||||
|
armleft = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||||
|
armright = RoboticArm()
|
||||||
|
|
||||||
|
|
||||||
|
lefthandle = armleft.rm_create_robot_arm("169.254.128.18", 8080)
|
||||||
|
print("机械臂ID:", lefthandle.id)
|
||||||
|
righthandle = armright.rm_create_robot_arm("169.254.128.19", 8080)
|
||||||
|
print("机械臂ID:", righthandle.id)
|
||||||
|
|
||||||
|
# software_info = armleft.rm_get_arm_software_info()
|
||||||
|
# if software_info[0] == 0:
|
||||||
|
# print("\n================== Arm Software Information ==================")
|
||||||
|
# print("Arm Model: ", software_info[1]['product_version'])
|
||||||
|
# print("Algorithm Library Version: ", software_info[1]['algorithm_info']['version'])
|
||||||
|
# print("Control Layer Software Version: ", software_info[1]['ctrl_info']['version'])
|
||||||
|
# print("Dynamics Version: ", software_info[1]['dynamic_info']['model_version'])
|
||||||
|
# print("Planning Layer Software Version: ", software_info[1]['plan_info']['version'])
|
||||||
|
# print("==============================================================\n")
|
||||||
|
# else:
|
||||||
|
# print("\nFailed to get arm software information, Error code: ", software_info[0], "\n")
|
||||||
|
|
||||||
|
print("Left: ", armleft.rm_get_current_arm_state())
|
||||||
|
print("Left: ", armleft.rm_get_arm_all_state())
|
||||||
|
armleft.rm_movej_p()
|
||||||
|
# print("Right: ", armright.rm_get_current_arm_state())
|
||||||
|
|
||||||
|
|
||||||
|
# 断开所有连接,销毁线程
|
||||||
|
RoboticArm.rm_destory()
|
||||||
15
realman_src/movep_canfd.py
Normal file
15
realman_src/movep_canfd.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from Robotic_Arm.rm_robot_interface import *
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 实例化RoboticArm类
|
||||||
|
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
|
||||||
|
# 创建机械臂连接,打印连接id
|
||||||
|
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
|
||||||
|
print(handle.id)
|
||||||
|
|
||||||
|
print(arm.rm_movep_follow([-0.330512, 0.255993, -0.161205, 3.141, 0.0, -1.57]))
|
||||||
|
time.sleep(2)
|
||||||
|
# print(arm.rm_movep_follow([0.3, 0, 0.3, 3.14, 0, 0]))
|
||||||
|
# time.sleep(2)
|
||||||
|
|
||||||
|
arm.rm_delete_robot_arm()
|
||||||
0
realman_src/realman_aloha/__init__.py
Normal file
0
realman_src/realman_aloha/__init__.py
Normal file
4
realman_src/realman_aloha/shadow_camera/.gitignore
vendored
Normal file
4
realman_src/realman_aloha/shadow_camera/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pt
|
||||||
0
realman_src/realman_aloha/shadow_camera/README.md
Normal file
0
realman_src/realman_aloha/shadow_camera/README.md
Normal file
0
realman_src/realman_aloha/shadow_camera/__init__.py
Normal file
0
realman_src/realman_aloha/shadow_camera/__init__.py
Normal file
33
realman_src/realman_aloha/shadow_camera/pyproject.toml
Normal file
33
realman_src/realman_aloha/shadow_camera/pyproject.toml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "shadow_camera"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "camera class, currently includes realsense"
|
||||||
|
readme = "README.md"
|
||||||
|
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||||
|
license = "MIT"
|
||||||
|
#include = ["realman_vision/pytransform/_pytransform.so",]
|
||||||
|
classifiers = [
|
||||||
|
"Operating System :: POSIX :: Linux amd64",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.9"
|
||||||
|
numpy = ">=2.0.1"
|
||||||
|
opencv-python = ">=4.10.0.84"
|
||||||
|
pyrealsense2 = ">=2.55.1.6486"
|
||||||
|
|
||||||
|
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||||
|
pytest = ">=8.3"
|
||||||
|
black = ">=24.10.0"
|
||||||
|
|
||||||
|
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.8.4"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCamera(metaclass=ABCMeta):
|
||||||
|
"""摄像头基类"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start_camera(self):
|
||||||
|
"""启动相机"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def stop_camera(self):
|
||||||
|
"""停止相机"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_resolution(self, resolution_width, resolution_height):
|
||||||
|
"""设置相机彩色图像分辨率"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def set_frame_rate(self, fps):
|
||||||
|
"""设置相机彩色图像帧率"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def read_frame(self):
|
||||||
|
"""读取一帧彩色图像和深度图像"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_camera_intrinsics(self):
|
||||||
|
"""获取彩色图像和深度图像的内参"""
|
||||||
|
pass
|
||||||
Binary file not shown.
@@ -0,0 +1,38 @@
|
|||||||
|
from shadow_camera import base_camera
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
class OpenCVCamera(base_camera.BaseCamera):
|
||||||
|
"""基于OpenCV的摄像头类"""
|
||||||
|
|
||||||
|
def __init__(self, device_id=0):
|
||||||
|
"""初始化视频捕获
|
||||||
|
|
||||||
|
参数:
|
||||||
|
device_id: 摄像头设备ID
|
||||||
|
"""
|
||||||
|
self.cap = cv2.VideoCapture(device_id)
|
||||||
|
|
||||||
|
def get_frame(self):
|
||||||
|
"""获取当前帧
|
||||||
|
|
||||||
|
返回:
|
||||||
|
frame: 当前帧的图像数据,取不到时返回None
|
||||||
|
"""
|
||||||
|
ret, frame = self.cap.read()
|
||||||
|
return frame if ret else None
|
||||||
|
|
||||||
|
def get_frame_info(self):
|
||||||
|
"""获取当前帧信息
|
||||||
|
|
||||||
|
返回:
|
||||||
|
dict: 帧信息字典
|
||||||
|
"""
|
||||||
|
width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||||
|
height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||||
|
channels = int(self.cap.get(cv2.CAP_PROP_FRAME_CHANNELS))
|
||||||
|
|
||||||
|
return {
|
||||||
|
'width': width,
|
||||||
|
'height': height,
|
||||||
|
'channels': channels
|
||||||
|
}
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,280 @@
|
|||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import pyrealsense2 as rs
|
||||||
|
import base_camera
|
||||||
|
|
||||||
|
# 设置日志配置
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RealSenseCamera(base_camera.BaseCamera):
|
||||||
|
"""Intel RealSense相机类"""
|
||||||
|
|
||||||
|
def __init__(self, serial_num=None, is_depth_frame=False):
|
||||||
|
"""
|
||||||
|
初始化相机对象
|
||||||
|
:param serial_num: 相机序列号,默认为None
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._color_resolution = [640, 480]
|
||||||
|
self._depth_resolution = [640, 480]
|
||||||
|
self._color_frames_rate = 30
|
||||||
|
self._depth_frames_rate = 15
|
||||||
|
self.timestamp = 0
|
||||||
|
self.color_timestamp = 0
|
||||||
|
self.depth_timestamp = 0
|
||||||
|
self._colorizer = rs.colorizer()
|
||||||
|
self._config = rs.config()
|
||||||
|
self.is_depth_frame = is_depth_frame
|
||||||
|
self.camera_on = False
|
||||||
|
self.serial_num = serial_num
|
||||||
|
|
||||||
|
def get_serial_num(self):
|
||||||
|
serial_num = {}
|
||||||
|
context = rs.context()
|
||||||
|
devices = context.query_devices() # 获取所有设备
|
||||||
|
if len(context.devices) > 0:
|
||||||
|
for i, device in enumerate(devices):
|
||||||
|
serial_num[i] = device.get_info(rs.camera_info.serial_number)
|
||||||
|
|
||||||
|
logging.info(f"Detected serial numbers: {serial_num}")
|
||||||
|
return serial_num
|
||||||
|
|
||||||
|
def _set_config(self):
|
||||||
|
if self.serial_num is not None:
|
||||||
|
logging.info(f"Setting device with serial number: {self.serial_num}")
|
||||||
|
self._config.enable_device(self.serial_num)
|
||||||
|
|
||||||
|
self._config.enable_stream(
|
||||||
|
rs.stream.color,
|
||||||
|
self._color_resolution[0],
|
||||||
|
self._color_resolution[1],
|
||||||
|
rs.format.rgb8,
|
||||||
|
self._color_frames_rate,
|
||||||
|
)
|
||||||
|
if self.is_depth_frame:
|
||||||
|
self._config.enable_stream(
|
||||||
|
rs.stream.depth,
|
||||||
|
self._depth_resolution[0],
|
||||||
|
self._depth_resolution[1],
|
||||||
|
rs.format.z16,
|
||||||
|
self._depth_frames_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
def start_camera(self):
|
||||||
|
"""
|
||||||
|
启动相机并获取内参信息,如果后续调用帧对齐,则内参均为彩色内参
|
||||||
|
"""
|
||||||
|
self._pipeline = rs.pipeline()
|
||||||
|
if self.is_depth_frame:
|
||||||
|
self.point_cloud = rs.pointcloud()
|
||||||
|
self._align = rs.align(rs.stream.color)
|
||||||
|
self._set_config()
|
||||||
|
|
||||||
|
self.profile = self._pipeline.start(self._config)
|
||||||
|
|
||||||
|
if self.is_depth_frame:
|
||||||
|
self._depth_intrinsics = (
|
||||||
|
self.profile.get_stream(rs.stream.depth)
|
||||||
|
.as_video_stream_profile()
|
||||||
|
.get_intrinsics()
|
||||||
|
)
|
||||||
|
|
||||||
|
self._color_intrinsics = (
|
||||||
|
self.profile.get_stream(rs.stream.color)
|
||||||
|
.as_video_stream_profile()
|
||||||
|
.get_intrinsics()
|
||||||
|
)
|
||||||
|
self.camera_on = True
|
||||||
|
logging.info("Camera started successfully")
|
||||||
|
logging.info(
|
||||||
|
f"Camera started with color resolution: {self._color_resolution}, depth resolution: {self._depth_resolution}"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
f"Color FPS: {self._color_frames_rate}, Depth FPS: {self._depth_frames_rate}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def stop_camera(self):
|
||||||
|
"""
|
||||||
|
停止相机
|
||||||
|
"""
|
||||||
|
self._pipeline.stop()
|
||||||
|
self.camera_on = False
|
||||||
|
logging.info("Camera stopped")
|
||||||
|
|
||||||
|
def set_resolution(self, color_resolution, depth_resolution):
|
||||||
|
self._color_resolution = color_resolution
|
||||||
|
self._depth_resolution = depth_resolution
|
||||||
|
logging.info(
|
||||||
|
"Optional color resolution:"
|
||||||
|
"[320, 180] [320, 240] [424, 240] [640, 360] [640, 480]"
|
||||||
|
"[848, 480] [960, 540] [1280, 720] [1920, 1080]"
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"Optional depth resolution:"
|
||||||
|
"[256, 144] [424, 240] [480, 270] [640, 360] [640, 400]"
|
||||||
|
"[640, 480] [848, 100] [848, 480] [1280, 720] [1280, 800]"
|
||||||
|
)
|
||||||
|
logging.info(f"Set color resolution to: {color_resolution}")
|
||||||
|
logging.info(f"Set depth resolution to: {depth_resolution}")
|
||||||
|
|
||||||
|
def set_frame_rate(self, color_fps, depth_fps):
|
||||||
|
self._color_frames_rate = color_fps
|
||||||
|
self._depth_frames_rate = depth_fps
|
||||||
|
logging.info("Optional color fps: 6 15 30 60 ")
|
||||||
|
logging.info("Optional depth fps: 6 15 30 60 90 100 300")
|
||||||
|
logging.info(f"Set color FPS to: {color_fps}")
|
||||||
|
logging.info(f"Set depth FPS to: {depth_fps}")
|
||||||
|
|
||||||
|
# TODO: 调节白平衡进行补偿
|
||||||
|
# def set_exposure(self, exposure):
|
||||||
|
|
||||||
|
def read_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
|
||||||
|
"""
|
||||||
|
读取一帧彩色图像和深度图像
|
||||||
|
:return: 彩色图像和深度图像的NumPy数组
|
||||||
|
"""
|
||||||
|
while not self.camera_on:
|
||||||
|
time.sleep(0.5)
|
||||||
|
color_image = None
|
||||||
|
depth_image = None
|
||||||
|
colorized_depth = None
|
||||||
|
point_cloud = None
|
||||||
|
try:
|
||||||
|
frames = self._pipeline.wait_for_frames()
|
||||||
|
if is_color:
|
||||||
|
color_frame = frames.get_color_frame()
|
||||||
|
color_image = np.asanyarray(color_frame.get_data())
|
||||||
|
else:
|
||||||
|
color_image = None
|
||||||
|
|
||||||
|
if is_depth:
|
||||||
|
depth_frame = frames.get_depth_frame()
|
||||||
|
depth_image = np.asanyarray(depth_frame.get_data())
|
||||||
|
else:
|
||||||
|
depth_image = None
|
||||||
|
|
||||||
|
colorized_depth = (
|
||||||
|
np.asanyarray(self._colorizer.colorize(depth_frame).get_data())
|
||||||
|
if is_colorized_depth
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
point_cloud = (
|
||||||
|
np.asanyarray(self.point_cloud.calculate(depth_frame).get_vertices())
|
||||||
|
if is_point_cloud
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
# 获取时间戳单位为ms,对齐后color时间戳 > depth = aligned,选择color
|
||||||
|
self.color_timestamp = color_frame.get_timestamp()
|
||||||
|
if self.is_depth_frame:
|
||||||
|
self.depth_timestamp = depth_frame.get_timestamp()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(e)
|
||||||
|
if "Frame didn't arrive within 5000" in str(e):
|
||||||
|
logging.warning("Frame didn't arrive within 5000ms, resetting device")
|
||||||
|
self.stop_camera()
|
||||||
|
self.start_camera()
|
||||||
|
|
||||||
|
return color_image, depth_image, colorized_depth, point_cloud
|
||||||
|
|
||||||
|
def read_align_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
|
||||||
|
"""
|
||||||
|
读取一帧对齐的彩色图像和深度图像
|
||||||
|
:return: 彩色图像和深度图像的NumPy数组
|
||||||
|
"""
|
||||||
|
while not self.camera_on:
|
||||||
|
time.sleep(0.5)
|
||||||
|
try:
|
||||||
|
frames = self._pipeline.wait_for_frames()
|
||||||
|
aligned_frames = self._align.process(frames)
|
||||||
|
aligned_color_frame = aligned_frames.get_color_frame()
|
||||||
|
self._aligned_depth_frame = aligned_frames.get_depth_frame()
|
||||||
|
|
||||||
|
color_image = np.asanyarray(aligned_color_frame.get_data())
|
||||||
|
depth_image = np.asanyarray(self._aligned_depth_frame.get_data())
|
||||||
|
colorized_depth = (
|
||||||
|
np.asanyarray(
|
||||||
|
self._colorizer.colorize(self._aligned_depth_frame).get_data()
|
||||||
|
)
|
||||||
|
if is_colorized_depth
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_point_cloud:
|
||||||
|
points = self.point_cloud.calculate(self._aligned_depth_frame)
|
||||||
|
# 将元组数据转换为 NumPy 数组
|
||||||
|
point_cloud = np.array(
|
||||||
|
[[point[0], point[1], point[2]] for point in points.get_vertices()]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
point_cloud = None
|
||||||
|
|
||||||
|
# 获取时间戳单位为ms,对齐后color时间戳 > depth = aligned,选择color
|
||||||
|
self.timestamp = aligned_color_frame.get_timestamp()
|
||||||
|
|
||||||
|
return color_image, depth_image, colorized_depth, point_cloud
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if "Frame didn't arrive within 5000" in str(e):
|
||||||
|
logging.warning("Frame didn't arrive within 5000ms, resetting device")
|
||||||
|
self.stop_camera()
|
||||||
|
self.start_camera()
|
||||||
|
# device = self.profile.get_device()
|
||||||
|
# device.hardware_reset()
|
||||||
|
|
||||||
|
def get_camera_intrinsics(self):
|
||||||
|
"""
|
||||||
|
获取彩色图像和深度图像的内参信息
|
||||||
|
:return: 彩色图像和深度图像的内参信息
|
||||||
|
"""
|
||||||
|
# 宽高:.width, .height; 焦距:.fx, .fy; 像素坐标:.ppx, .ppy; 畸变系数:.coeffs
|
||||||
|
logging.info("Getting camera intrinsics")
|
||||||
|
logging.info(
|
||||||
|
"Width and height: .width, .height; Focal length: .fx, .fy; Pixel coordinates: .ppx, .ppy; Distortion coefficient: .coeffs"
|
||||||
|
)
|
||||||
|
return self._color_intrinsics, self._depth_intrinsics
|
||||||
|
|
||||||
|
def get_3d_camera_coordinate(self, depth_pixel, align=False):
|
||||||
|
"""
|
||||||
|
获取深度相机坐标系下的三维坐标
|
||||||
|
:param depth_pixel:深度像素坐标
|
||||||
|
:param align: 是否对齐
|
||||||
|
|
||||||
|
:return: 深度值和相机坐标
|
||||||
|
"""
|
||||||
|
if not hasattr(self, "_aligned_depth_frame"):
|
||||||
|
raise AttributeError(
|
||||||
|
"Aligned depth frame not set. Call read_align_frame() first."
|
||||||
|
)
|
||||||
|
|
||||||
|
distance = self._aligned_depth_frame.get_distance(
|
||||||
|
depth_pixel[0], depth_pixel[1]
|
||||||
|
)
|
||||||
|
intrinsics = self._color_intrinsics if align else self._depth_intrinsics
|
||||||
|
camera_coordinate = rs.rs2_deproject_pixel_to_point(
|
||||||
|
intrinsics, depth_pixel, distance
|
||||||
|
)
|
||||||
|
return distance, camera_coordinate
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
camera = RealSenseCamera(is_depth_frame=False)
|
||||||
|
camera.get_serial_num()
|
||||||
|
camera.start_camera()
|
||||||
|
# camera.set_frame_rate(60, 60)
|
||||||
|
color_image, depth_image, colorized_depth, point_cloud = camera.read_frame()
|
||||||
|
camera.stop_camera()
|
||||||
|
logging.info(f"Color image shape: {color_image.shape}")
|
||||||
|
# logging.info(f"Depth image shape: {depth_image.shape}")
|
||||||
|
# logging.info(f"Colorized depth image shape: {colorized_depth.shape}")
|
||||||
|
# logging.info(f"Point cloud shape: {point_cloud.shape}")
|
||||||
|
logging.info(f"Color timestamp: {camera.timestamp}")
|
||||||
|
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
|
||||||
|
logging.info(f"Color timestamp: {camera.color_timestamp}")
|
||||||
|
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
|
||||||
|
logging.info("Test passed")
|
||||||
@@ -0,0 +1,101 @@
|
|||||||
|
import pyrealsense2 as rs
|
||||||
|
import numpy as np
|
||||||
|
import h5py
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
import keyboard # 用于监听键盘输入
|
||||||
|
|
||||||
|
# 全局变量
|
||||||
|
is_recording = False # 标志位,控制录制状态
|
||||||
|
color_images = [] # 存储彩色图像
|
||||||
|
depth_images = [] # 存储深度图像
|
||||||
|
timestamps = [] # 存储时间戳
|
||||||
|
|
||||||
|
# 配置D435相机
|
||||||
|
def configure_camera():
|
||||||
|
pipeline = rs.pipeline()
|
||||||
|
config = rs.config()
|
||||||
|
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) # 彩色图像流
|
||||||
|
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) # 深度图像流
|
||||||
|
pipeline.start(config)
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
# 监听键盘输入,控制录制状态
|
||||||
|
def listen_for_keyboard():
|
||||||
|
global is_recording
|
||||||
|
while True:
|
||||||
|
if keyboard.is_pressed('s'): # 按下 's' 开始录制
|
||||||
|
is_recording = True
|
||||||
|
print("Recording started.")
|
||||||
|
time.sleep(0.5) # 防止重复触发
|
||||||
|
elif keyboard.is_pressed('q'): # 按下 'q' 停止录制
|
||||||
|
is_recording = False
|
||||||
|
print("Recording stopped.")
|
||||||
|
time.sleep(0.5) # 防止重复触发
|
||||||
|
elif keyboard.is_pressed('e'): # 按下 'e' 退出程序
|
||||||
|
print("Exiting program.")
|
||||||
|
exit()
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# 采集图像数据
|
||||||
|
def capture_frames(pipeline):
|
||||||
|
global is_recording, color_images, depth_images, timestamps
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if is_recording:
|
||||||
|
frames = pipeline.wait_for_frames()
|
||||||
|
color_frame = frames.get_color_frame()
|
||||||
|
depth_frame = frames.get_depth_frame()
|
||||||
|
|
||||||
|
if not color_frame or not depth_frame:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取当前时间戳
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
# 将图像转换为numpy数组
|
||||||
|
color_image = np.asanyarray(color_frame.get_data())
|
||||||
|
depth_image = np.asanyarray(depth_frame.get_data())
|
||||||
|
|
||||||
|
# 存储数据
|
||||||
|
color_images.append(color_image)
|
||||||
|
depth_images.append(depth_image)
|
||||||
|
timestamps.append(timestamp)
|
||||||
|
|
||||||
|
print(f"Captured frame at {timestamp}")
|
||||||
|
|
||||||
|
else:
|
||||||
|
time.sleep(0.1) # 如果未录制,等待一段时间
|
||||||
|
|
||||||
|
finally:
|
||||||
|
pipeline.stop()
|
||||||
|
|
||||||
|
# 保存为HDF5文件
|
||||||
|
def save_to_hdf5(color_images, depth_images, timestamps, filename="output.h5"):
|
||||||
|
with h5py.File(filename, "w") as f:
|
||||||
|
f.create_dataset("color_images", data=np.array(color_images), compression="gzip")
|
||||||
|
f.create_dataset("depth_images", data=np.array(depth_images), compression="gzip")
|
||||||
|
f.create_dataset("timestamps", data=np.array(timestamps), compression="gzip")
|
||||||
|
print(f"Data saved to {filename}")
|
||||||
|
|
||||||
|
# 主函数
|
||||||
|
def main():
|
||||||
|
global is_recording, color_images, depth_images, timestamps
|
||||||
|
|
||||||
|
# 启动键盘监听线程
|
||||||
|
keyboard_thread = threading.Thread(target=listen_for_keyboard)
|
||||||
|
keyboard_thread.daemon = True
|
||||||
|
keyboard_thread.start()
|
||||||
|
|
||||||
|
# 配置相机
|
||||||
|
pipeline = configure_camera()
|
||||||
|
|
||||||
|
# 开始采集图像
|
||||||
|
capture_frames(pipeline)
|
||||||
|
|
||||||
|
# 录制结束后保存数据
|
||||||
|
if color_images and depth_images and timestamps:
|
||||||
|
save_to_hdf5(color_images, depth_images, timestamps, "mobile_aloha_data.h5")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
152
realman_src/realman_aloha/shadow_camera/test/test_camera.py
Normal file
152
realman_src/realman_aloha/shadow_camera/test/test_camera.py
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from os import path
|
||||||
|
import pyrealsense2 as rs
|
||||||
|
from shadow_camera import realsense
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_camera():
|
||||||
|
camera = realsense.RealSenseCamera('241122071186')
|
||||||
|
camera.start_camera()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# result = camera.read_align_frame()
|
||||||
|
# if result is None:
|
||||||
|
# print('is None')
|
||||||
|
# continue
|
||||||
|
# start_time = time.time()
|
||||||
|
color_image, depth_image, colorized_depth, vtx = camera.read_frame()
|
||||||
|
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
print(f"color_image: {color_image.shape}")
|
||||||
|
# print(f"Time: {end_time - start_time}")
|
||||||
|
cv2.imshow("bgr_image", color_image)
|
||||||
|
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||||
|
break
|
||||||
|
camera.stop_camera()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_serial_num():
|
||||||
|
camera = realsense.RealSenseCamera()
|
||||||
|
device = camera.get_serial_num()
|
||||||
|
|
||||||
|
|
||||||
|
class CameraCapture:
|
||||||
|
def __init__(self, camera_serial_num=None, save_dir="./save"):
|
||||||
|
self._camera_serial_num = camera_serial_num
|
||||||
|
self._color_save_dir = path.join(save_dir, "color")
|
||||||
|
self._depth_save_dir = path.join(save_dir, "depth")
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
os.makedirs(self._color_save_dir, exist_ok=True)
|
||||||
|
os.makedirs(self._depth_save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def get_serial_num(self):
|
||||||
|
self._camera_serial_num = {}
|
||||||
|
camera_names = ["left", "right", "head", "table"]
|
||||||
|
context = rs.context()
|
||||||
|
devices = context.query_devices() # 获取所有设备
|
||||||
|
if len(context.devices) > 0:
|
||||||
|
for i, device in enumerate(devices):
|
||||||
|
self._camera_serial_num[camera_names[i]] = device.get_info(
|
||||||
|
rs.camera_info.serial_number
|
||||||
|
)
|
||||||
|
print(self._camera_serial_num)
|
||||||
|
|
||||||
|
return self._camera_serial_num
|
||||||
|
|
||||||
|
def start_camera(self):
|
||||||
|
if self._camera_serial_num is None:
|
||||||
|
self.get_serial_num()
|
||||||
|
self._camera_left = realsense.RealSenseCamera(self._camera_serial_num["left"])
|
||||||
|
self._camera_right = realsense.RealSenseCamera(self._camera_serial_num["right"])
|
||||||
|
self._camera_head = realsense.RealSenseCamera(self._camera_serial_num["head"])
|
||||||
|
|
||||||
|
self._camera_left.start_camera()
|
||||||
|
self._camera_right.start_camera()
|
||||||
|
self._camera_head.start_camera()
|
||||||
|
|
||||||
|
def stop_camera(self):
|
||||||
|
self._camera_left.stop_camera()
|
||||||
|
self._camera_right.stop_camera()
|
||||||
|
self._camera_head.stop_camera()
|
||||||
|
|
||||||
|
def _save_datas(self, timestamp, color_image, depth_image, camera_name):
|
||||||
|
color_filename = path.join(
|
||||||
|
self._color_save_dir, f"{timestamp}" + camera_name + ".jpg"
|
||||||
|
)
|
||||||
|
depth_filename = path.join(
|
||||||
|
self._depth_save_dir, f"{timestamp}" + camera_name + ".png"
|
||||||
|
)
|
||||||
|
cv2.imwrite(color_filename, color_image)
|
||||||
|
cv2.imwrite(depth_filename, depth_image)
|
||||||
|
|
||||||
|
def capture_images(self):
|
||||||
|
while True:
|
||||||
|
(
|
||||||
|
color_image_left,
|
||||||
|
depth_image_left,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
) = self._camera_left.read_align_frame()
|
||||||
|
(
|
||||||
|
color_image_right,
|
||||||
|
depth_image_right,
|
||||||
|
_,
|
||||||
|
_,
|
||||||
|
) = self._camera_right.read_align_frame()
|
||||||
|
(
|
||||||
|
color_image_head,
|
||||||
|
depth_image_head,
|
||||||
|
_,
|
||||||
|
point_cloud3,
|
||||||
|
) = self._camera_head.read_align_frame()
|
||||||
|
|
||||||
|
bgr_color_image_left = cv2.cvtColor(color_image_left, cv2.COLOR_RGB2BGR)
|
||||||
|
bgr_color_image_right = cv2.cvtColor(color_image_right, cv2.COLOR_RGB2BGR)
|
||||||
|
bgr_color_image_head = cv2.cvtColor(color_image_head, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
timestamp = time.time() * 1000
|
||||||
|
|
||||||
|
cv2.imshow("Camera left", bgr_color_image_left)
|
||||||
|
cv2.imshow("Camera right", bgr_color_image_right)
|
||||||
|
cv2.imshow("Camera head", bgr_color_image_head)
|
||||||
|
|
||||||
|
# self._save_datas(
|
||||||
|
# timestamp, bgr_color_image_left, depth_image_left, "left"
|
||||||
|
# )
|
||||||
|
# self._save_datas(
|
||||||
|
# timestamp, bgr_color_image_right, depth_image_right, "right"
|
||||||
|
# )
|
||||||
|
# self._save_datas(
|
||||||
|
# timestamp, bgr_color_image_head, depth_image_head, "head"
|
||||||
|
# )
|
||||||
|
|
||||||
|
if cv2.waitKey(1) & 0xFF == ord("q"):
|
||||||
|
break
|
||||||
|
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
#test_camera()
|
||||||
|
test_get_serial_num()
|
||||||
|
"""
|
||||||
|
输入相机序列号制定左右相机:
|
||||||
|
dict:{'left': '241222075132', 'right': '242322076532', 'head': '242322076532'}
|
||||||
|
保存路径:
|
||||||
|
str:./save
|
||||||
|
输入为空,自动分配相机序列号(不指定左、右、头部),保存路径为./save
|
||||||
|
"""
|
||||||
|
|
||||||
|
# capture = CameraCapture()
|
||||||
|
# capture.get_serial_num()
|
||||||
|
# test_get_serial_num()
|
||||||
|
|
||||||
|
# capture.start_camera()
|
||||||
|
# capture.capture_images()
|
||||||
|
# capture.stop_camera()
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
import pytest
|
||||||
|
import pyrealsense2 as rs
|
||||||
|
from shadow_camera.realsense import RealSenseCamera
|
||||||
|
|
||||||
|
|
||||||
|
class TestRealSenseCamera:
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_camera(self):
|
||||||
|
self.camera = RealSenseCamera()
|
||||||
|
|
||||||
|
def test_get_serial_num(self):
|
||||||
|
serial_nums = self.camera.get_serial_num()
|
||||||
|
assert isinstance(serial_nums, dict)
|
||||||
|
assert len(serial_nums) > 0
|
||||||
|
|
||||||
|
def test_start_stop_camera(self):
|
||||||
|
self.camera.start_camera()
|
||||||
|
assert self.camera.camera_on is True
|
||||||
|
self.camera.stop_camera()
|
||||||
|
assert self.camera.camera_on is False
|
||||||
|
|
||||||
|
def test_set_resolution(self):
|
||||||
|
color_resolution = [1280, 720]
|
||||||
|
depth_resolution = [1280, 720]
|
||||||
|
self.camera.set_resolution(color_resolution, depth_resolution)
|
||||||
|
assert self.camera._color_resolution == color_resolution
|
||||||
|
assert self.camera._depth_resolution == depth_resolution
|
||||||
|
|
||||||
|
def test_set_frame_rate(self):
|
||||||
|
color_fps = 60
|
||||||
|
depth_fps = 60
|
||||||
|
self.camera.set_frame_rate(color_fps, depth_fps)
|
||||||
|
assert self.camera._color_frames_rate == color_fps
|
||||||
|
assert self.camera._depth_frames_rate == depth_fps
|
||||||
|
|
||||||
|
def test_read_frame(self):
|
||||||
|
self.camera.start_camera()
|
||||||
|
color_image, depth_image, colorized_depth, point_cloud = (
|
||||||
|
self.camera.read_frame()
|
||||||
|
)
|
||||||
|
assert color_image is not None
|
||||||
|
assert depth_image is not None
|
||||||
|
self.camera.stop_camera()
|
||||||
|
|
||||||
|
def test_read_align_frame(self):
|
||||||
|
self.camera.start_camera()
|
||||||
|
color_image, depth_image, colorized_depth, point_cloud = (
|
||||||
|
self.camera.read_align_frame()
|
||||||
|
)
|
||||||
|
assert color_image is not None
|
||||||
|
assert depth_image is not None
|
||||||
|
self.camera.stop_camera()
|
||||||
|
|
||||||
|
def test_get_camera_intrinsics(self):
|
||||||
|
self.camera.start_camera()
|
||||||
|
color_intrinsics, depth_intrinsics = self.camera.get_camera_intrinsics()
|
||||||
|
assert color_intrinsics is not None
|
||||||
|
assert depth_intrinsics is not None
|
||||||
|
self.camera.stop_camera()
|
||||||
|
|
||||||
|
def test_get_3d_camera_coordinate(self):
|
||||||
|
self.camera.start_camera()
|
||||||
|
# 先调用 read_align_frame 方法以确保 _aligned_depth_frame 被设置
|
||||||
|
self.camera.read_align_frame()
|
||||||
|
depth_pixel = [320, 240]
|
||||||
|
distance, camera_coordinate = self.camera.get_3d_camera_coordinate(
|
||||||
|
depth_pixel, align=True
|
||||||
|
)
|
||||||
|
assert distance > 0
|
||||||
|
assert len(camera_coordinate) == 3
|
||||||
|
self.camera.stop_camera()
|
||||||
10
realman_src/realman_aloha/shadow_rm_act/.gitignore
vendored
Normal file
10
realman_src/realman_aloha/shadow_rm_act/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
__pycache__/
|
||||||
|
build/
|
||||||
|
devel/
|
||||||
|
dist/
|
||||||
|
data/
|
||||||
|
.catkin_workspace
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pt
|
||||||
|
.vscode/
|
||||||
89
realman_src/realman_aloha/shadow_rm_act/README.md
Normal file
89
realman_src/realman_aloha/shadow_rm_act/README.md
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# ACT: Action Chunking with Transformers
|
||||||
|
|
||||||
|
### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
|
||||||
|
TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve way after loss plateaus.
|
||||||
|
|
||||||
|
#### Project Website: https://tonyzhaozh.github.io/aloha/
|
||||||
|
|
||||||
|
This repo contains the implementation of ACT, together with 2 simulated environments:
|
||||||
|
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or real.
|
||||||
|
For real, you would also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).
|
||||||
|
|
||||||
|
### Updates:
|
||||||
|
You can find all scripted/human demo for simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).
|
||||||
|
|
||||||
|
|
||||||
|
### Repo Structure
|
||||||
|
- ``imitate_episodes.py`` Train and Evaluate ACT
|
||||||
|
- ``policy.py`` An adaptor for ACT policy
|
||||||
|
- ``detr`` Model definitions of ACT, modified from DETR
|
||||||
|
- ``sim_env.py`` Mujoco + DM_Control environments with joint space control
|
||||||
|
- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control
|
||||||
|
- ``scripted_policy.py`` Scripted policies for sim environments
|
||||||
|
- ``constants.py`` Constants shared across files
|
||||||
|
- ``utils.py`` Utils such as data loading and helper functions
|
||||||
|
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset
|
||||||
|
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
conda create -n aloha python=3.8.10
|
||||||
|
conda activate aloha
|
||||||
|
pip install torchvision
|
||||||
|
pip install torch
|
||||||
|
pip install pyquaternion
|
||||||
|
pip install pyyaml
|
||||||
|
pip install rospkg
|
||||||
|
pip install pexpect
|
||||||
|
pip install mujoco==2.3.7
|
||||||
|
pip install dm_control==1.0.14
|
||||||
|
pip install opencv-python
|
||||||
|
pip install matplotlib
|
||||||
|
pip install einops
|
||||||
|
pip install packaging
|
||||||
|
pip install h5py
|
||||||
|
pip install ipython
|
||||||
|
cd act/detr && pip install -e .
|
||||||
|
|
||||||
|
### Example Usages
|
||||||
|
|
||||||
|
To set up a new terminal, run:
|
||||||
|
|
||||||
|
conda activate aloha
|
||||||
|
cd <path to act repo>
|
||||||
|
|
||||||
|
### Simulated experiments
|
||||||
|
|
||||||
|
We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
|
||||||
|
To generated 50 episodes of scripted data, run:
|
||||||
|
|
||||||
|
python3 record_sim_episodes.py \
|
||||||
|
--task_name sim_transfer_cube_scripted \
|
||||||
|
--dataset_dir <data save dir> \
|
||||||
|
--num_episodes 50
|
||||||
|
|
||||||
|
To can add the flag ``--onscreen_render`` to see real-time rendering.
|
||||||
|
To visualize the episode after it is collected, run
|
||||||
|
|
||||||
|
python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0
|
||||||
|
|
||||||
|
To train ACT:
|
||||||
|
|
||||||
|
# Transfer Cube task
|
||||||
|
python3 imitate_episodes.py \
|
||||||
|
--task_name sim_transfer_cube_scripted \
|
||||||
|
--ckpt_dir <ckpt dir> \
|
||||||
|
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
|
||||||
|
--num_epochs 2000 --lr 1e-5 \
|
||||||
|
--seed 0
|
||||||
|
|
||||||
|
|
||||||
|
To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
|
||||||
|
The success rate should be around 90% for transfer cube, and around 50% for insertion.
|
||||||
|
To enable temporal ensembling, add flag ``--temporal_agg``.
|
||||||
|
Videos will be saved to ``<ckpt_dir>`` for each rollout.
|
||||||
|
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
|
||||||
|
|
||||||
|
For real-world data where things can be harder to model, train for at least 5000 epochs or 3-4 times the length after the loss has plateaued.
|
||||||
|
Please refer to [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.
|
||||||
|
|
||||||
74
realman_src/realman_aloha/shadow_rm_act/config/config.yaml
Normal file
74
realman_src/realman_aloha/shadow_rm_act/config/config.yaml
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
robot_env: {
|
||||||
|
# TODO change the path to the correct one
|
||||||
|
rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
|
||||||
|
rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
|
||||||
|
arm_axis: 6,
|
||||||
|
head_camera: '215222076892',
|
||||||
|
bottom_camera: '215222076981',
|
||||||
|
left_camera: '152122078151',
|
||||||
|
right_camera: '152122073489',
|
||||||
|
# init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
|
||||||
|
# init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
|
||||||
|
init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
|
||||||
|
init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
|
||||||
|
}
|
||||||
|
dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
|
||||||
|
checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
|
||||||
|
# checkpoint_name: 'policy_best.ckpt'
|
||||||
|
checkpoint_name: 'policy_9500.ckpt'
|
||||||
|
state_dim: 14
|
||||||
|
save_episode: True
|
||||||
|
num_rollouts: 50 #训练期间要收集的 rollout(轨迹)数量
|
||||||
|
real_robot: True
|
||||||
|
policy_class: 'ACT'
|
||||||
|
onscreen_render: False
|
||||||
|
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
|
||||||
|
episode_len: 300 #episode 的最大长度(时间步数)。
|
||||||
|
task_name: 'aloha_01_11.28'
|
||||||
|
temporal_agg: False #是否使用时间聚合
|
||||||
|
batch_size: 8 #训练期间每批的样本数。
|
||||||
|
seed: 1000 #随机种子。
|
||||||
|
chunk_size: 30 #用于处理序列的块大小
|
||||||
|
eval_every: 1 #每隔 eval_every 步评估一次模型。
|
||||||
|
num_steps: 10000 #训练的总步数。
|
||||||
|
validate_every: 1 #每隔 validate_every 步验证一次模型。
|
||||||
|
save_every: 500 #每隔 save_every 步保存一次检查点。
|
||||||
|
load_pretrain: False #是否加载预训练模型。
|
||||||
|
resume_ckpt_path:
|
||||||
|
name_filter: # TODO
|
||||||
|
skip_mirrored_data: False #是否跳过镜像数据(例如用于基于对称性的数据增强)。
|
||||||
|
stats_dir:
|
||||||
|
sample_weights:
|
||||||
|
train_ratio: 0.8 #用于训练的数据比例(其余数据用于验证)
|
||||||
|
|
||||||
|
policy_config: {
|
||||||
|
hidden_dim: 512, # Size of the embeddings (dimension of the transformer)
|
||||||
|
state_dim: 14, # Dimension of the state
|
||||||
|
position_embedding: 'sine', # ('sine', 'learned').Type of positional embedding to use on top of the image features
|
||||||
|
lr_backbone: 1.0e-5,
|
||||||
|
masks: False, # If true, the model masks the non-visible pixels
|
||||||
|
backbone: 'resnet18',
|
||||||
|
dilation: False, # If true, we replace stride with dilation in the last convolutional block (DC5)
|
||||||
|
dropout: 0.1, # Dropout applied in the transformer
|
||||||
|
nheads: 8,
|
||||||
|
dim_feedforward: 3200, # Intermediate size of the feedforward layers in the transformer blocks
|
||||||
|
enc_layers: 4, # Number of encoding layers in the transformer
|
||||||
|
dec_layers: 7, # Number of decoding layers in the transformer
|
||||||
|
pre_norm: False, # If true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
|
||||||
|
num_queries: 30,
|
||||||
|
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
|
||||||
|
vq: False,
|
||||||
|
vq_class: none,
|
||||||
|
vq_dim: 64,
|
||||||
|
action_dim: 14,
|
||||||
|
no_encoder: False,
|
||||||
|
lr: 1.0e-5,
|
||||||
|
weight_decay: 1.0e-4,
|
||||||
|
kl_weight: 10,
|
||||||
|
|
||||||
|
# lr_drop: 200,
|
||||||
|
# clip_max_norm: 0.1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
267
realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py
Normal file
267
realman_src/realman_aloha/shadow_rm_act/ee_sim_env.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
import numpy as np
|
||||||
|
import collections
|
||||||
|
import os
|
||||||
|
|
||||||
|
from constants import DT, XML_DIR, START_ARM_POSE
|
||||||
|
from constants import PUPPET_GRIPPER_POSITION_CLOSE
|
||||||
|
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||||
|
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||||
|
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||||
|
|
||||||
|
from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.rl import control
|
||||||
|
from dm_control.suite import base
|
||||||
|
|
||||||
|
import IPython
|
||||||
|
e = IPython.embed
|
||||||
|
|
||||||
|
|
||||||
|
def make_ee_sim_env(task_name):
|
||||||
|
"""
|
||||||
|
Environment for simulated robot bi-manual manipulation, with end-effector control.
|
||||||
|
Action space: [left_arm_pose (7), # position and quaternion for end effector
|
||||||
|
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_pose (7), # position and quaternion for end effector
|
||||||
|
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||||
|
|
||||||
|
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||||
|
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_qpos (6), # absolute joint position
|
||||||
|
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||||
|
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||||
|
"""
|
||||||
|
if 'sim_transfer_cube' in task_name:
|
||||||
|
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
|
||||||
|
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||||
|
task = TransferCubeEETask(random=False)
|
||||||
|
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||||
|
n_sub_steps=None, flat_observation=False)
|
||||||
|
elif 'sim_insertion' in task_name:
|
||||||
|
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
|
||||||
|
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||||
|
task = InsertionEETask(random=False)
|
||||||
|
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||||
|
n_sub_steps=None, flat_observation=False)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return env
|
||||||
|
|
||||||
|
class BimanualViperXEETask(base.Task):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
|
||||||
|
def before_step(self, action, physics):
|
||||||
|
a_len = len(action) // 2
|
||||||
|
action_left = action[:a_len]
|
||||||
|
action_right = action[a_len:]
|
||||||
|
|
||||||
|
# set mocap position and quat
|
||||||
|
# left
|
||||||
|
np.copyto(physics.data.mocap_pos[0], action_left[:3])
|
||||||
|
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
|
||||||
|
# right
|
||||||
|
np.copyto(physics.data.mocap_pos[1], action_right[:3])
|
||||||
|
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
|
||||||
|
|
||||||
|
# set gripper
|
||||||
|
g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
|
||||||
|
g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
|
||||||
|
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
|
||||||
|
|
||||||
|
def initialize_robots(self, physics):
|
||||||
|
# reset joint position
|
||||||
|
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||||
|
|
||||||
|
# reset mocap to align with end effector
|
||||||
|
# to obtain these numbers:
|
||||||
|
# (1) make an ee_sim env and reset to the same start_pose
|
||||||
|
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
|
||||||
|
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
|
||||||
|
# repeat the same for right side
|
||||||
|
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
|
||||||
|
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
|
||||||
|
# right
|
||||||
|
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
|
||||||
|
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
|
||||||
|
|
||||||
|
# reset gripper control
|
||||||
|
close_gripper_control = np.array([
|
||||||
|
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
-PUPPET_GRIPPER_POSITION_CLOSE,
|
||||||
|
])
|
||||||
|
np.copyto(physics.data.ctrl, close_gripper_control)
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qpos(physics):
|
||||||
|
qpos_raw = physics.data.qpos.copy()
|
||||||
|
left_qpos_raw = qpos_raw[:8]
|
||||||
|
right_qpos_raw = qpos_raw[8:16]
|
||||||
|
left_arm_qpos = left_qpos_raw[:6]
|
||||||
|
right_arm_qpos = right_qpos_raw[:6]
|
||||||
|
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||||
|
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qvel(physics):
|
||||||
|
qvel_raw = physics.data.qvel.copy()
|
||||||
|
left_qvel_raw = qvel_raw[:8]
|
||||||
|
right_qvel_raw = qvel_raw[8:16]
|
||||||
|
left_arm_qvel = left_qvel_raw[:6]
|
||||||
|
right_arm_qvel = right_qvel_raw[:6]
|
||||||
|
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||||
|
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
# note: it is important to do .copy()
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs['qpos'] = self.get_qpos(physics)
|
||||||
|
obs['qvel'] = self.get_qvel(physics)
|
||||||
|
obs['env_state'] = self.get_env_state(physics)
|
||||||
|
obs['images'] = dict()
|
||||||
|
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
|
||||||
|
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
|
||||||
|
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
|
||||||
|
# used in scripted policy to obtain starting pose
|
||||||
|
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
|
||||||
|
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
|
||||||
|
|
||||||
|
# used when replaying joint trajectory
|
||||||
|
obs['gripper_ctrl'] = physics.data.ctrl.copy()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TransferCubeEETask(BimanualViperXEETask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
self.initialize_robots(physics)
|
||||||
|
# randomize box position
|
||||||
|
cube_pose = sample_box_pose()
|
||||||
|
box_start_idx = physics.model.name2id('red_box_joint', 'joint')
|
||||||
|
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
|
||||||
|
# print(f"randomized cube position to {cube_position}")
|
||||||
|
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether left gripper is holding the box
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_table = ("red_box", "table") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_right_gripper:
|
||||||
|
reward = 1
|
||||||
|
if touch_right_gripper and not touch_table: # lifted
|
||||||
|
reward = 2
|
||||||
|
if touch_left_gripper: # attempted transfer
|
||||||
|
reward = 3
|
||||||
|
if touch_left_gripper and not touch_table: # successful transfer
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
class InsertionEETask(BimanualViperXEETask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
self.initialize_robots(physics)
|
||||||
|
# randomize peg and socket position
|
||||||
|
peg_pose, socket_pose = sample_insertion_pose()
|
||||||
|
id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
|
||||||
|
|
||||||
|
peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
|
||||||
|
peg_start_idx = id2index(peg_start_id)
|
||||||
|
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
|
||||||
|
# print(f"randomized cube position to {cube_position}")
|
||||||
|
|
||||||
|
socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
|
||||||
|
socket_start_idx = id2index(socket_start_id)
|
||||||
|
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
|
||||||
|
# print(f"randomized cube position to {cube_position}")
|
||||||
|
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether peg touches the pin
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||||
|
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||||
|
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||||
|
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
|
||||||
|
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||||
|
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||||
|
("socket-2", "table") in all_contact_pairs or \
|
||||||
|
("socket-3", "table") in all_contact_pairs or \
|
||||||
|
("socket-4", "table") in all_contact_pairs
|
||||||
|
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||||
|
("red_peg", "socket-2") in all_contact_pairs or \
|
||||||
|
("red_peg", "socket-3") in all_contact_pairs or \
|
||||||
|
("red_peg", "socket-4") in all_contact_pairs
|
||||||
|
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_left_gripper and touch_right_gripper: # touch both
|
||||||
|
reward = 1
|
||||||
|
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||||
|
reward = 2
|
||||||
|
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||||
|
reward = 3
|
||||||
|
if pin_touched: # successful insertion
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
36
realman_src/realman_aloha/shadow_rm_act/pyproject.toml
Normal file
36
realman_src/realman_aloha/shadow_rm_act/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "shadow_act"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Embodied data, ACT and other methods; training and verification function packages"
|
||||||
|
readme = "README.md"
|
||||||
|
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||||
|
license = "MIT"
|
||||||
|
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||||
|
classifiers = [
|
||||||
|
"Operating System :: POSIX :: Linux amd64",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.9"
|
||||||
|
wandb = ">=0.18.0"
|
||||||
|
einops = ">=0.8.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||||
|
pytest = ">=8.3"
|
||||||
|
black = ">=24.10.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.8.4"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
189
realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py
Normal file
189
realman_src/realman_aloha/shadow_rm_act/record_sim_episodes.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
import time
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import argparse
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import h5py
|
||||||
|
|
||||||
|
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
|
||||||
|
from ee_sim_env import make_ee_sim_env
|
||||||
|
from sim_env import make_sim_env, BOX_POSE
|
||||||
|
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
|
||||||
|
|
||||||
|
import IPython
|
||||||
|
e = IPython.embed
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
"""
|
||||||
|
Generate demonstration data in simulation.
|
||||||
|
First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory.
|
||||||
|
Replace the gripper joint positions with the commanded joint position.
|
||||||
|
Replay this joint trajectory (as action sequence) in sim_env, and record all observations.
|
||||||
|
Save this episode of data, and continue to next episode of data collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
task_name = args['task_name']
|
||||||
|
dataset_dir = args['dataset_dir']
|
||||||
|
num_episodes = args['num_episodes']
|
||||||
|
onscreen_render = args['onscreen_render']
|
||||||
|
inject_noise = False
|
||||||
|
render_cam_name = 'angle'
|
||||||
|
|
||||||
|
if not os.path.isdir(dataset_dir):
|
||||||
|
os.makedirs(dataset_dir, exist_ok=True)
|
||||||
|
|
||||||
|
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||||
|
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
|
||||||
|
if task_name == 'sim_transfer_cube_scripted':
|
||||||
|
policy_cls = PickAndTransferPolicy
|
||||||
|
elif task_name == 'sim_insertion_scripted':
|
||||||
|
policy_cls = InsertionPolicy
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
success = []
|
||||||
|
for episode_idx in range(num_episodes):
|
||||||
|
print(f'{episode_idx=}')
|
||||||
|
print('Rollout out EE space scripted policy')
|
||||||
|
# setup the environment
|
||||||
|
env = make_ee_sim_env(task_name)
|
||||||
|
ts = env.reset()
|
||||||
|
episode = [ts]
|
||||||
|
policy = policy_cls(inject_noise)
|
||||||
|
# setup plotting
|
||||||
|
if onscreen_render:
|
||||||
|
ax = plt.subplot()
|
||||||
|
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||||
|
plt.ion()
|
||||||
|
for step in range(episode_len):
|
||||||
|
action = policy(ts)
|
||||||
|
ts = env.step(action)
|
||||||
|
episode.append(ts)
|
||||||
|
if onscreen_render:
|
||||||
|
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||||
|
plt.pause(0.002)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||||
|
episode_max_reward = np.max([ts.reward for ts in episode[1:]])
|
||||||
|
if episode_max_reward == env.task.max_reward:
|
||||||
|
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||||
|
else:
|
||||||
|
print(f"{episode_idx=} Failed")
|
||||||
|
|
||||||
|
joint_traj = [ts.observation['qpos'] for ts in episode]
|
||||||
|
# replace gripper pose with gripper control
|
||||||
|
gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
|
||||||
|
for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
|
||||||
|
left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
|
||||||
|
right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
|
||||||
|
joint[6] = left_ctrl
|
||||||
|
joint[6+7] = right_ctrl
|
||||||
|
|
||||||
|
subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0
|
||||||
|
|
||||||
|
# clear unused variables
|
||||||
|
del env
|
||||||
|
del episode
|
||||||
|
del policy
|
||||||
|
|
||||||
|
# setup the environment
|
||||||
|
print('Replaying joint commands')
|
||||||
|
env = make_sim_env(task_name)
|
||||||
|
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
|
||||||
|
ts = env.reset()
|
||||||
|
|
||||||
|
episode_replay = [ts]
|
||||||
|
# setup plotting
|
||||||
|
if onscreen_render:
|
||||||
|
ax = plt.subplot()
|
||||||
|
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
|
||||||
|
plt.ion()
|
||||||
|
for t in range(len(joint_traj)): # note: this will increase episode length by 1
|
||||||
|
action = joint_traj[t]
|
||||||
|
ts = env.step(action)
|
||||||
|
episode_replay.append(ts)
|
||||||
|
if onscreen_render:
|
||||||
|
plt_img.set_data(ts.observation['images'][render_cam_name])
|
||||||
|
plt.pause(0.02)
|
||||||
|
|
||||||
|
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
|
||||||
|
episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
|
||||||
|
if episode_max_reward == env.task.max_reward:
|
||||||
|
success.append(1)
|
||||||
|
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||||
|
else:
|
||||||
|
success.append(0)
|
||||||
|
print(f"{episode_idx=} Failed")
|
||||||
|
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
"""
|
||||||
|
For each timestep:
|
||||||
|
observations
|
||||||
|
- images
|
||||||
|
- each_cam_name (480, 640, 3) 'uint8'
|
||||||
|
- qpos (14,) 'float64'
|
||||||
|
- qvel (14,) 'float64'
|
||||||
|
|
||||||
|
action (14,) 'float64'
|
||||||
|
"""
|
||||||
|
|
||||||
|
data_dict = {
|
||||||
|
'/observations/qpos': [],
|
||||||
|
'/observations/qvel': [],
|
||||||
|
'/action': [],
|
||||||
|
}
|
||||||
|
for cam_name in camera_names:
|
||||||
|
data_dict[f'/observations/images/{cam_name}'] = []
|
||||||
|
|
||||||
|
# because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
|
||||||
|
# truncate here to be consistent
|
||||||
|
joint_traj = joint_traj[:-1]
|
||||||
|
episode_replay = episode_replay[:-1]
|
||||||
|
|
||||||
|
# len(joint_traj) i.e. actions: max_timesteps
|
||||||
|
# len(episode_replay) i.e. time steps: max_timesteps + 1
|
||||||
|
max_timesteps = len(joint_traj)
|
||||||
|
while joint_traj:
|
||||||
|
action = joint_traj.pop(0)
|
||||||
|
ts = episode_replay.pop(0)
|
||||||
|
data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||||
|
data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||||
|
data_dict['/action'].append(action)
|
||||||
|
for cam_name in camera_names:
|
||||||
|
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||||
|
|
||||||
|
# HDF5
|
||||||
|
t0 = time.time()
|
||||||
|
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
|
||||||
|
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
|
||||||
|
root.attrs['sim'] = True
|
||||||
|
obs = root.create_group('observations')
|
||||||
|
image = obs.create_group('images')
|
||||||
|
for cam_name in camera_names:
|
||||||
|
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
|
||||||
|
chunks=(1, 480, 640, 3), )
|
||||||
|
# compression='gzip',compression_opts=2,)
|
||||||
|
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
|
||||||
|
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
|
||||||
|
qvel = obs.create_dataset('qvel', (max_timesteps, 14))
|
||||||
|
action = root.create_dataset('action', (max_timesteps, 14))
|
||||||
|
|
||||||
|
for name, array in data_dict.items():
|
||||||
|
root[name][...] = array
|
||||||
|
print(f'Saving: {time.time() - t0:.1f} secs\n')
|
||||||
|
|
||||||
|
print(f'Saved to {dataset_dir}')
|
||||||
|
print(f'Success: {np.sum(success)} / {len(success)}')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
|
||||||
|
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
|
||||||
|
parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
|
||||||
|
parser.add_argument('--onscreen_render', action='store_true')
|
||||||
|
|
||||||
|
main(vars(parser.parse_args()))
|
||||||
|
|
||||||
194
realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
Normal file
194
realman_src/realman_aloha/shadow_rm_act/scripted_policy.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from pyquaternion import Quaternion
|
||||||
|
|
||||||
|
from constants import SIM_TASK_CONFIGS
|
||||||
|
from ee_sim_env import make_ee_sim_env
|
||||||
|
|
||||||
|
import IPython
|
||||||
|
e = IPython.embed
|
||||||
|
|
||||||
|
|
||||||
|
class BasePolicy:
|
||||||
|
def __init__(self, inject_noise=False):
|
||||||
|
self.inject_noise = inject_noise
|
||||||
|
self.step_count = 0
|
||||||
|
self.left_trajectory = None
|
||||||
|
self.right_trajectory = None
|
||||||
|
|
||||||
|
def generate_trajectory(self, ts_first):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def interpolate(curr_waypoint, next_waypoint, t):
|
||||||
|
t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
|
||||||
|
curr_xyz = curr_waypoint['xyz']
|
||||||
|
curr_quat = curr_waypoint['quat']
|
||||||
|
curr_grip = curr_waypoint['gripper']
|
||||||
|
next_xyz = next_waypoint['xyz']
|
||||||
|
next_quat = next_waypoint['quat']
|
||||||
|
next_grip = next_waypoint['gripper']
|
||||||
|
xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
|
||||||
|
quat = curr_quat + (next_quat - curr_quat) * t_frac
|
||||||
|
gripper = curr_grip + (next_grip - curr_grip) * t_frac
|
||||||
|
return xyz, quat, gripper
|
||||||
|
|
||||||
|
def __call__(self, ts):
|
||||||
|
# generate trajectory at first timestep, then open-loop execution
|
||||||
|
if self.step_count == 0:
|
||||||
|
self.generate_trajectory(ts)
|
||||||
|
|
||||||
|
# obtain left and right waypoints
|
||||||
|
if self.left_trajectory[0]['t'] == self.step_count:
|
||||||
|
self.curr_left_waypoint = self.left_trajectory.pop(0)
|
||||||
|
next_left_waypoint = self.left_trajectory[0]
|
||||||
|
|
||||||
|
if self.right_trajectory[0]['t'] == self.step_count:
|
||||||
|
self.curr_right_waypoint = self.right_trajectory.pop(0)
|
||||||
|
next_right_waypoint = self.right_trajectory[0]
|
||||||
|
|
||||||
|
# interpolate between waypoints to obtain current pose and gripper command
|
||||||
|
left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
|
||||||
|
right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)
|
||||||
|
|
||||||
|
# Inject noise
|
||||||
|
if self.inject_noise:
|
||||||
|
scale = 0.01
|
||||||
|
left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
|
||||||
|
right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)
|
||||||
|
|
||||||
|
action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
|
||||||
|
action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])
|
||||||
|
|
||||||
|
self.step_count += 1
|
||||||
|
return np.concatenate([action_left, action_right])
|
||||||
|
|
||||||
|
|
||||||
|
class PickAndTransferPolicy(BasePolicy):
|
||||||
|
|
||||||
|
def generate_trajectory(self, ts_first):
|
||||||
|
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||||
|
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||||
|
|
||||||
|
box_info = np.array(ts_first.observation['env_state'])
|
||||||
|
box_xyz = box_info[:3]
|
||||||
|
box_quat = box_info[3:]
|
||||||
|
# print(f"Generate trajectory for {box_xyz=}")
|
||||||
|
|
||||||
|
gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
|
||||||
|
gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||||
|
|
||||||
|
meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
|
||||||
|
|
||||||
|
meet_xyz = np.array([0, 0.5, 0.25])
|
||||||
|
|
||||||
|
self.left_trajectory = [
|
||||||
|
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||||
|
{"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position
|
||||||
|
{"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position
|
||||||
|
{"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper
|
||||||
|
{"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left
|
||||||
|
{"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay
|
||||||
|
]
|
||||||
|
|
||||||
|
self.right_trajectory = [
|
||||||
|
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||||
|
{"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube
|
||||||
|
{"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down
|
||||||
|
{"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper
|
||||||
|
{"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position
|
||||||
|
{"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position
|
||||||
|
{"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper
|
||||||
|
{"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right
|
||||||
|
{"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class InsertionPolicy(BasePolicy):
|
||||||
|
|
||||||
|
def generate_trajectory(self, ts_first):
|
||||||
|
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
|
||||||
|
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
|
||||||
|
|
||||||
|
peg_info = np.array(ts_first.observation['env_state'])[:7]
|
||||||
|
peg_xyz = peg_info[:3]
|
||||||
|
peg_quat = peg_info[3:]
|
||||||
|
|
||||||
|
socket_info = np.array(ts_first.observation['env_state'])[7:]
|
||||||
|
socket_xyz = socket_info[:3]
|
||||||
|
socket_quat = socket_info[3:]
|
||||||
|
|
||||||
|
gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
|
||||||
|
gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
|
||||||
|
|
||||||
|
gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
|
||||||
|
gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)
|
||||||
|
|
||||||
|
meet_xyz = np.array([0, 0.5, 0.15])
|
||||||
|
lift_right = 0.00715
|
||||||
|
|
||||||
|
self.left_trajectory = [
|
||||||
|
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
|
||||||
|
{"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube
|
||||||
|
{"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down
|
||||||
|
{"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper
|
||||||
|
{"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position
|
||||||
|
{"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion
|
||||||
|
{"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion
|
||||||
|
]
|
||||||
|
|
||||||
|
self.right_trajectory = [
|
||||||
|
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
|
||||||
|
{"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube
|
||||||
|
{"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down
|
||||||
|
{"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper
|
||||||
|
{"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position
|
||||||
|
{"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||||
|
{"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
|
||||||
|
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_policy(task_name):
|
||||||
|
# example rolling out pick_and_transfer policy
|
||||||
|
onscreen_render = True
|
||||||
|
inject_noise = False
|
||||||
|
|
||||||
|
# setup the environment
|
||||||
|
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
|
||||||
|
if 'sim_transfer_cube' in task_name:
|
||||||
|
env = make_ee_sim_env('sim_transfer_cube')
|
||||||
|
elif 'sim_insertion' in task_name:
|
||||||
|
env = make_ee_sim_env('sim_insertion')
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
for episode_idx in range(2):
|
||||||
|
ts = env.reset()
|
||||||
|
episode = [ts]
|
||||||
|
if onscreen_render:
|
||||||
|
ax = plt.subplot()
|
||||||
|
plt_img = ax.imshow(ts.observation['images']['angle'])
|
||||||
|
plt.ion()
|
||||||
|
|
||||||
|
policy = PickAndTransferPolicy(inject_noise)
|
||||||
|
for step in range(episode_len):
|
||||||
|
action = policy(ts)
|
||||||
|
ts = env.step(action)
|
||||||
|
episode.append(ts)
|
||||||
|
if onscreen_render:
|
||||||
|
plt_img.set_data(ts.observation['images']['angle'])
|
||||||
|
plt.pause(0.02)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
episode_return = np.sum([ts.reward for ts in episode[1:]])
|
||||||
|
if episode_return > 0:
|
||||||
|
print(f"{episode_idx=} Successful, {episode_return=}")
|
||||||
|
else:
|
||||||
|
print(f"{episode_idx=} Failed")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_task_name = 'sim_transfer_cube_scripted'
|
||||||
|
test_policy(test_task_name)
|
||||||
|
|
||||||
278
realman_src/realman_aloha/shadow_rm_act/sim_env.py
Normal file
278
realman_src/realman_aloha/shadow_rm_act/sim_env.py
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import collections
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from dm_control import mujoco
|
||||||
|
from dm_control.rl import control
|
||||||
|
from dm_control.suite import base
|
||||||
|
|
||||||
|
from constants import DT, XML_DIR, START_ARM_POSE
|
||||||
|
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
|
||||||
|
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
|
||||||
|
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
|
||||||
|
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
|
||||||
|
|
||||||
|
import IPython
|
||||||
|
e = IPython.embed
|
||||||
|
|
||||||
|
BOX_POSE = [None] # to be changed from outside
|
||||||
|
|
||||||
|
def make_sim_env(task_name):
|
||||||
|
"""
|
||||||
|
Environment for simulated robot bi-manual manipulation, with joint position control
|
||||||
|
Action space: [left_arm_qpos (6), # absolute joint position
|
||||||
|
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_qpos (6), # absolute joint position
|
||||||
|
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
|
||||||
|
|
||||||
|
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
|
||||||
|
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
|
||||||
|
right_arm_qpos (6), # absolute joint position
|
||||||
|
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
|
||||||
|
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
right_arm_qvel (6), # absolute joint velocity (rad)
|
||||||
|
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
|
||||||
|
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
|
||||||
|
"""
|
||||||
|
if 'sim_transfer_cube' in task_name:
|
||||||
|
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml')
|
||||||
|
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||||
|
task = TransferCubeTask(random=False)
|
||||||
|
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||||
|
n_sub_steps=None, flat_observation=False)
|
||||||
|
elif 'sim_insertion' in task_name:
|
||||||
|
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml')
|
||||||
|
physics = mujoco.Physics.from_xml_path(xml_path)
|
||||||
|
task = InsertionTask(random=False)
|
||||||
|
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
|
||||||
|
n_sub_steps=None, flat_observation=False)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return env
|
||||||
|
|
||||||
|
class BimanualViperXTask(base.Task):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
|
||||||
|
def before_step(self, action, physics):
|
||||||
|
left_arm_action = action[:6]
|
||||||
|
right_arm_action = action[7:7+6]
|
||||||
|
normalized_left_gripper_action = action[6]
|
||||||
|
normalized_right_gripper_action = action[7+6]
|
||||||
|
|
||||||
|
left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
|
||||||
|
right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)
|
||||||
|
|
||||||
|
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
|
||||||
|
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
|
||||||
|
|
||||||
|
env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
|
||||||
|
super().before_step(env_action, physics)
|
||||||
|
return
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qpos(physics):
|
||||||
|
qpos_raw = physics.data.qpos.copy()
|
||||||
|
left_qpos_raw = qpos_raw[:8]
|
||||||
|
right_qpos_raw = qpos_raw[8:16]
|
||||||
|
left_arm_qpos = left_qpos_raw[:6]
|
||||||
|
right_arm_qpos = right_qpos_raw[:6]
|
||||||
|
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
|
||||||
|
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_qvel(physics):
|
||||||
|
qvel_raw = physics.data.qvel.copy()
|
||||||
|
left_qvel_raw = qvel_raw[:8]
|
||||||
|
right_qvel_raw = qvel_raw[8:16]
|
||||||
|
left_arm_qvel = left_qvel_raw[:6]
|
||||||
|
right_arm_qvel = right_qvel_raw[:6]
|
||||||
|
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
|
||||||
|
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
|
||||||
|
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def get_observation(self, physics):
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs['qpos'] = self.get_qpos(physics)
|
||||||
|
obs['qvel'] = self.get_qvel(physics)
|
||||||
|
obs['env_state'] = self.get_env_state(physics)
|
||||||
|
obs['images'] = dict()
|
||||||
|
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
|
||||||
|
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
|
||||||
|
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
|
||||||
|
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether left gripper is holding the box
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TransferCubeTask(BimanualViperXTask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||||
|
# reset qpos, control and box position
|
||||||
|
with physics.reset_context():
|
||||||
|
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||||
|
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||||
|
assert BOX_POSE[0] is not None
|
||||||
|
physics.named.data.qpos[-7:] = BOX_POSE[0]
|
||||||
|
# print(f"{BOX_POSE=}")
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether left gripper is holding the box
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_table = ("red_box", "table") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_right_gripper:
|
||||||
|
reward = 1
|
||||||
|
if touch_right_gripper and not touch_table: # lifted
|
||||||
|
reward = 2
|
||||||
|
if touch_left_gripper: # attempted transfer
|
||||||
|
reward = 3
|
||||||
|
if touch_left_gripper and not touch_table: # successful transfer
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
class InsertionTask(BimanualViperXTask):
|
||||||
|
def __init__(self, random=None):
|
||||||
|
super().__init__(random=random)
|
||||||
|
self.max_reward = 4
|
||||||
|
|
||||||
|
def initialize_episode(self, physics):
|
||||||
|
"""Sets the state of the environment at the start of each episode."""
|
||||||
|
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
|
||||||
|
# reset qpos, control and box position
|
||||||
|
with physics.reset_context():
|
||||||
|
physics.named.data.qpos[:16] = START_ARM_POSE
|
||||||
|
np.copyto(physics.data.ctrl, START_ARM_POSE)
|
||||||
|
assert BOX_POSE[0] is not None
|
||||||
|
physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects
|
||||||
|
# print(f"{BOX_POSE=}")
|
||||||
|
super().initialize_episode(physics)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_env_state(physics):
|
||||||
|
env_state = physics.data.qpos.copy()[16:]
|
||||||
|
return env_state
|
||||||
|
|
||||||
|
def get_reward(self, physics):
|
||||||
|
# return whether peg touches the pin
|
||||||
|
all_contact_pairs = []
|
||||||
|
for i_contact in range(physics.data.ncon):
|
||||||
|
id_geom_1 = physics.data.contact[i_contact].geom1
|
||||||
|
id_geom_2 = physics.data.contact[i_contact].geom2
|
||||||
|
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
|
||||||
|
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
|
||||||
|
contact_pair = (name_geom_1, name_geom_2)
|
||||||
|
all_contact_pairs.append(contact_pair)
|
||||||
|
|
||||||
|
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
|
||||||
|
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||||
|
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||||
|
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
|
||||||
|
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
|
||||||
|
|
||||||
|
peg_touch_table = ("red_peg", "table") in all_contact_pairs
|
||||||
|
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
|
||||||
|
("socket-2", "table") in all_contact_pairs or \
|
||||||
|
("socket-3", "table") in all_contact_pairs or \
|
||||||
|
("socket-4", "table") in all_contact_pairs
|
||||||
|
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
|
||||||
|
("red_peg", "socket-2") in all_contact_pairs or \
|
||||||
|
("red_peg", "socket-3") in all_contact_pairs or \
|
||||||
|
("red_peg", "socket-4") in all_contact_pairs
|
||||||
|
pin_touched = ("red_peg", "pin") in all_contact_pairs
|
||||||
|
|
||||||
|
reward = 0
|
||||||
|
if touch_left_gripper and touch_right_gripper: # touch both
|
||||||
|
reward = 1
|
||||||
|
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
|
||||||
|
reward = 2
|
||||||
|
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
|
||||||
|
reward = 3
|
||||||
|
if pin_touched: # successful insertion
|
||||||
|
reward = 4
|
||||||
|
return reward
|
||||||
|
|
||||||
|
|
||||||
|
def get_action(master_bot_left, master_bot_right):
|
||||||
|
action = np.zeros(14)
|
||||||
|
# arm action
|
||||||
|
action[:6] = master_bot_left.dxl.joint_states.position[:6]
|
||||||
|
action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
|
||||||
|
# gripper action
|
||||||
|
left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
|
||||||
|
right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
|
||||||
|
normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
|
||||||
|
normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
|
||||||
|
action[6] = normalized_left_pos
|
||||||
|
action[7+6] = normalized_right_pos
|
||||||
|
return action
|
||||||
|
|
||||||
|
def test_sim_teleop():
|
||||||
|
""" Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """
|
||||||
|
from interbotix_xs_modules.arm import InterbotixManipulatorXS
|
||||||
|
|
||||||
|
BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]
|
||||||
|
|
||||||
|
# source of data
|
||||||
|
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||||
|
robot_name=f'master_left', init_node=True)
|
||||||
|
master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
|
||||||
|
robot_name=f'master_right', init_node=False)
|
||||||
|
|
||||||
|
# setup the environment
|
||||||
|
env = make_sim_env('sim_transfer_cube')
|
||||||
|
ts = env.reset()
|
||||||
|
episode = [ts]
|
||||||
|
# setup plotting
|
||||||
|
ax = plt.subplot()
|
||||||
|
plt_img = ax.imshow(ts.observation['images']['angle'])
|
||||||
|
plt.ion()
|
||||||
|
|
||||||
|
for t in range(1000):
|
||||||
|
action = get_action(master_bot_left, master_bot_right)
|
||||||
|
ts = env.step(action)
|
||||||
|
episode.append(ts)
|
||||||
|
|
||||||
|
plt_img.set_data(ts.observation['images']['angle'])
|
||||||
|
plt.pause(0.02)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_sim_teleop()
|
||||||
|
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
@@ -0,0 +1,575 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import dm_env
|
||||||
|
import logging
|
||||||
|
import collections
|
||||||
|
import numpy as np
|
||||||
|
import tracemalloc
|
||||||
|
from einops import rearrange
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torchvision import transforms
|
||||||
|
from shadow_rm_robot.realman_arm import RmArm
|
||||||
|
from shadow_camera.realsense import RealSenseCamera
|
||||||
|
from shadow_act.models.latent_model import Latent_Model_Transformer
|
||||||
|
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||||
|
from shadow_act.utils.utils import set_seed
|
||||||
|
|
||||||
|
|
||||||
|
# 配置logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
# # 隐藏h5py的警告Creating converter from 7 to 5
|
||||||
|
# logging.getLogger("h5py").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
class RmActEvaluator:
|
||||||
|
def __init__(self, config, save_episode=True, num_rollouts=50):
|
||||||
|
"""
|
||||||
|
初始化Evaluator类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): 配置字典
|
||||||
|
checkpoint_name (str): 检查点名称
|
||||||
|
save_episode (bool): 是否保存每个episode
|
||||||
|
num_rollouts (int): 滚动次数
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
self._seed = config["seed"]
|
||||||
|
self.robot_env = config["robot_env"]
|
||||||
|
self.checkpoint_dir = config["checkpoint_dir"]
|
||||||
|
self.checkpoint_name = config["checkpoint_name"]
|
||||||
|
self.save_episode = save_episode
|
||||||
|
self.num_rollouts = num_rollouts
|
||||||
|
self.state_dim = config["state_dim"]
|
||||||
|
self.real_robot = config["real_robot"]
|
||||||
|
self.policy_class = config["policy_class"]
|
||||||
|
self.onscreen_render = config["onscreen_render"]
|
||||||
|
self.camera_names = config["camera_names"]
|
||||||
|
self.max_timesteps = config["episode_len"]
|
||||||
|
self.task_name = config["task_name"]
|
||||||
|
self.temporal_agg = config["temporal_agg"]
|
||||||
|
self.onscreen_cam = "angle"
|
||||||
|
self.policy_config = config["policy_config"]
|
||||||
|
self.vq = config["policy_config"]["vq"]
|
||||||
|
# self.actuator_config = config["actuator_config"]
|
||||||
|
# self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
|
||||||
|
self.stats = None
|
||||||
|
self.env = None
|
||||||
|
self.env_max_reward = 0
|
||||||
|
|
||||||
|
def _make_policy(self, policy_class, policy_config):
|
||||||
|
"""
|
||||||
|
根据策略类和配置创建策略对象
|
||||||
|
|
||||||
|
Args:
|
||||||
|
policy_class (str): 策略类名称
|
||||||
|
policy_config (dict): 策略配置字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
policy: 创建的策略对象
|
||||||
|
"""
|
||||||
|
if policy_class == "ACT":
|
||||||
|
return ACTPolicy(policy_config)
|
||||||
|
elif policy_class == "CNNMLP":
|
||||||
|
return CNNMLPPolicy(policy_config)
|
||||||
|
elif policy_class == "Diffusion":
|
||||||
|
return DiffusionPolicy(policy_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Policy class {policy_class} is not implemented")
|
||||||
|
|
||||||
|
def load_policy_and_stats(self):
|
||||||
|
"""
|
||||||
|
加载策略和统计数据
|
||||||
|
"""
|
||||||
|
checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
|
||||||
|
logging.info(f"Loading policy from: {checkpoint_path}")
|
||||||
|
self.policy = self._make_policy(self.policy_class, self.policy_config)
|
||||||
|
# 加载模型并设置为评估模式
|
||||||
|
self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
|
||||||
|
self.policy.cuda()
|
||||||
|
self.policy.eval()
|
||||||
|
|
||||||
|
if self.vq:
|
||||||
|
vq_dim = self.config["policy_config"]["vq_dim"]
|
||||||
|
vq_class = self.config["policy_config"]["vq_class"]
|
||||||
|
self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
|
||||||
|
latent_model_checkpoint_path = os.path.join(
|
||||||
|
self.checkpoint_dir, "latent_model_last.ckpt"
|
||||||
|
)
|
||||||
|
self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
|
||||||
|
self.latent_model.eval()
|
||||||
|
self.latent_model.cuda()
|
||||||
|
logging.info(
|
||||||
|
f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logging.info(f"Loaded: {checkpoint_path}")
|
||||||
|
|
||||||
|
stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
|
||||||
|
with open(stats_path, "rb") as f:
|
||||||
|
self.stats = pickle.load(f)
|
||||||
|
|
||||||
|
def pre_process(self, state_qpos):
|
||||||
|
"""
|
||||||
|
预处理状态位置
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_qpos (np.array): 状态位置数组
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 预处理后的状态位置
|
||||||
|
"""
|
||||||
|
if self.policy_class == "Diffusion":
|
||||||
|
return ((state_qpos + 1) / 2) * (
|
||||||
|
self.stats["action_max"] - self.stats["action_min"]
|
||||||
|
) + self.stats["action_min"]
|
||||||
|
# 标准化处理,均值为 0,标准差为 1
|
||||||
|
|
||||||
|
return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]
|
||||||
|
|
||||||
|
def post_process(self, action):
|
||||||
|
"""
|
||||||
|
后处理动作
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action (np.array): 动作数组
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 后处理后的动作
|
||||||
|
"""
|
||||||
|
# 反标准化处理
|
||||||
|
return action * self.stats["action_std"] + self.stats["action_mean"]
|
||||||
|
|
||||||
|
def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
|
||||||
|
"""
|
||||||
|
获取图像
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timestep (object): 时间步对象
|
||||||
|
camera_names (list): 相机名称列表
|
||||||
|
random_crop_resize (bool): 是否随机裁剪和调整大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: 处理后的图像,归一化(num_cameras, channels, height, width)
|
||||||
|
"""
|
||||||
|
current_images = []
|
||||||
|
for cam_name in camera_names:
|
||||||
|
current_image = rearrange(
|
||||||
|
timestep.observation["images"][cam_name], "h w c -> c h w"
|
||||||
|
)
|
||||||
|
current_images.append(current_image)
|
||||||
|
current_image = np.stack(current_images, axis=0)
|
||||||
|
current_image = (
|
||||||
|
torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
if random_crop_resize:
|
||||||
|
logging.info("Random crop resize is used!")
|
||||||
|
original_size = current_image.shape[-2:]
|
||||||
|
ratio = 0.95
|
||||||
|
current_image = current_image[
|
||||||
|
...,
|
||||||
|
int(original_size[0] * (1 - ratio) / 2) : int(
|
||||||
|
original_size[0] * (1 + ratio) / 2
|
||||||
|
),
|
||||||
|
int(original_size[1] * (1 - ratio) / 2) : int(
|
||||||
|
original_size[1] * (1 + ratio) / 2
|
||||||
|
),
|
||||||
|
]
|
||||||
|
current_image = current_image.squeeze(0)
|
||||||
|
resize_transform = transforms.Resize(original_size, antialias=True)
|
||||||
|
current_image = resize_transform(current_image)
|
||||||
|
current_image = current_image.unsqueeze(0)
|
||||||
|
|
||||||
|
return current_image
|
||||||
|
|
||||||
|
def load_environment(self):
|
||||||
|
"""
|
||||||
|
加载环境
|
||||||
|
"""
|
||||||
|
if self.real_robot:
|
||||||
|
self.env = DeviceAloha(self.robot_env)
|
||||||
|
self.env_max_reward = 0
|
||||||
|
else:
|
||||||
|
from sim_env import make_sim_env
|
||||||
|
|
||||||
|
self.env = make_sim_env(self.task_name)
|
||||||
|
self.env_max_reward = self.env.task.max_reward
|
||||||
|
|
||||||
|
def get_auto_index(self, checkpoint_dir):
|
||||||
|
max_idx = 1000
|
||||||
|
for i in range(max_idx + 1):
|
||||||
|
if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
|
||||||
|
return i
|
||||||
|
raise Exception(f"Error getting auto index, or more than {max_idx} episodes")
|
||||||
|
|
||||||
|
def evaluate(self, checkpoint_name=None):
|
||||||
|
"""
|
||||||
|
评估策略
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: 成功率和平均回报
|
||||||
|
"""
|
||||||
|
if checkpoint_name is not None:
|
||||||
|
self.checkpoint_name = checkpoint_name
|
||||||
|
set_seed(self._seed) # np与torch的随机种子
|
||||||
|
self.load_policy_and_stats()
|
||||||
|
self.load_environment()
|
||||||
|
|
||||||
|
query_frequency = self.policy_config["num_queries"]
|
||||||
|
|
||||||
|
# 时间聚合时,每个时间步只有1个查询
|
||||||
|
if self.temporal_agg:
|
||||||
|
query_frequency = 1
|
||||||
|
num_queries = self.policy_config["num_queries"]
|
||||||
|
|
||||||
|
# # 真实机器人时,基础延迟为13???
|
||||||
|
# if self.real_robot:
|
||||||
|
# BASE_DELAY = 13
|
||||||
|
# # query_frequency -= BASE_DELAY
|
||||||
|
|
||||||
|
max_timesteps = int(self.max_timesteps * 1) # may increase for real-world tasks
|
||||||
|
episode_returns = []
|
||||||
|
highest_rewards = []
|
||||||
|
|
||||||
|
for rollout_id in range(self.num_rollouts):
|
||||||
|
|
||||||
|
timestep = self.env.reset()
|
||||||
|
|
||||||
|
if self.onscreen_render:
|
||||||
|
# TODO 画图
|
||||||
|
pass
|
||||||
|
if self.temporal_agg:
|
||||||
|
all_time_actions = torch.zeros(
|
||||||
|
[max_timesteps, max_timesteps + num_queries, self.state_dim]
|
||||||
|
).cuda()
|
||||||
|
qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
|
||||||
|
rewards = []
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
time_0 = time.time()
|
||||||
|
DT = 1 / 30
|
||||||
|
culmulated_delay = 0
|
||||||
|
for t in range(max_timesteps):
|
||||||
|
time_1 = time.time()
|
||||||
|
if self.onscreen_render:
|
||||||
|
# TODO 显示图像
|
||||||
|
pass
|
||||||
|
# process previous timestep to get qpos and image_list
|
||||||
|
obs = timestep.observation
|
||||||
|
qpos_numpy = np.array(obs["qpos"])
|
||||||
|
qpos_history_raw[t] = qpos_numpy
|
||||||
|
qpos = self.pre_process(qpos_numpy)
|
||||||
|
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
|
||||||
|
|
||||||
|
logging.info(f"t{t}")
|
||||||
|
|
||||||
|
if t % query_frequency == 0:
|
||||||
|
current_image = self.get_image_torch(
|
||||||
|
timestep,
|
||||||
|
self.camera_names,
|
||||||
|
random_crop_resize=(
|
||||||
|
self.config["policy_class"] == "Diffusion"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
if t == 0:
|
||||||
|
# 网络预热
|
||||||
|
for _ in range(10):
|
||||||
|
self.policy(qpos, current_image)
|
||||||
|
logging.info("Network warm up done")
|
||||||
|
|
||||||
|
if self.config["policy_class"] == "ACT":
|
||||||
|
if t % query_frequency == 0:
|
||||||
|
if self.vq:
|
||||||
|
if rollout_id == 0:
|
||||||
|
for _ in range(10):
|
||||||
|
vq_sample = self.latent_model.generate(
|
||||||
|
1, temperature=1, x=None
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
torch.nonzero(vq_sample[0])[:, 1]
|
||||||
|
.cpu()
|
||||||
|
.numpy()
|
||||||
|
)
|
||||||
|
vq_sample = self.latent_model.generate(
|
||||||
|
1, temperature=1, x=None
|
||||||
|
)
|
||||||
|
all_actions = self.policy(
|
||||||
|
qpos, current_image, vq_sample=vq_sample
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
all_actions = self.policy(qpos, current_image)
|
||||||
|
# if self.real_robot:
|
||||||
|
# all_actions = torch.cat(
|
||||||
|
# [
|
||||||
|
# all_actions[:, :-BASE_DELAY, :-2],
|
||||||
|
# all_actions[:, BASE_DELAY:, -2:],
|
||||||
|
# ],
|
||||||
|
# dim=2,
|
||||||
|
# )
|
||||||
|
if self.temporal_agg:
|
||||||
|
all_time_actions[[t], t : t + num_queries] = all_actions
|
||||||
|
actions_for_curr_step = all_time_actions[:, t]
|
||||||
|
actions_populated = torch.all(
|
||||||
|
actions_for_curr_step != 0, axis=1
|
||||||
|
)
|
||||||
|
actions_for_curr_step = actions_for_curr_step[
|
||||||
|
actions_populated
|
||||||
|
]
|
||||||
|
k = 0.01
|
||||||
|
exp_weights = np.exp(
|
||||||
|
-k * np.arange(len(actions_for_curr_step))
|
||||||
|
)
|
||||||
|
exp_weights = exp_weights / exp_weights.sum()
|
||||||
|
exp_weights = (
|
||||||
|
torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
|
||||||
|
)
|
||||||
|
raw_action = (actions_for_curr_step * exp_weights).sum(
|
||||||
|
dim=0, keepdim=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_action = all_actions[:, t % query_frequency]
|
||||||
|
elif self.config["policy_class"] == "Diffusion":
|
||||||
|
if t % query_frequency == 0:
|
||||||
|
all_actions = self.policy(qpos, current_image)
|
||||||
|
# if self.real_robot:
|
||||||
|
# all_actions = torch.cat(
|
||||||
|
# [
|
||||||
|
# all_actions[:, :-BASE_DELAY, :-2],
|
||||||
|
# all_actions[:, BASE_DELAY:, -2:],
|
||||||
|
# ],
|
||||||
|
# dim=2,
|
||||||
|
# )
|
||||||
|
raw_action = all_actions[:, t % query_frequency]
|
||||||
|
elif self.config["policy_class"] == "CNNMLP":
|
||||||
|
raw_action = self.policy(qpos, current_image)
|
||||||
|
all_actions = raw_action.unsqueeze(0)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
### post-process actions
|
||||||
|
raw_action = raw_action.squeeze(0).cpu().numpy()
|
||||||
|
action = self.post_process(raw_action)
|
||||||
|
|
||||||
|
### step the environment
|
||||||
|
if self.real_robot:
|
||||||
|
logging.info(f" action = {action}")
|
||||||
|
timestep = self.env.step(action)
|
||||||
|
|
||||||
|
rewards.append(timestep.reward)
|
||||||
|
duration = time.time() - time_1
|
||||||
|
sleep_time = max(0, DT - duration)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
if duration >= DT:
|
||||||
|
culmulated_delay += duration - DT
|
||||||
|
logging.warning(
|
||||||
|
f"Warning: step duration: {duration:.3f} s at step {t} longer than DT: {DT} s, culmulated delay: {culmulated_delay:.3f} s"
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
if self.real_robot:
|
||||||
|
log_id = self.get_auto_index(self.checkpoint_dir)
|
||||||
|
np.save(
|
||||||
|
os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
|
||||||
|
qpos_history_raw,
|
||||||
|
)
|
||||||
|
plt.figure(figsize=(10, 20))
|
||||||
|
for i in range(self.state_dim):
|
||||||
|
plt.subplot(self.state_dim, 1, i + 1)
|
||||||
|
plt.plot(qpos_history_raw[:, i])
|
||||||
|
if i != self.state_dim - 1:
|
||||||
|
plt.xticks([])
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
rewards = np.array(rewards)
|
||||||
|
episode_return = np.sum(rewards[rewards != None])
|
||||||
|
episode_returns.append(episode_return)
|
||||||
|
episode_highest_reward = np.max(rewards)
|
||||||
|
highest_rewards.append(episode_highest_reward)
|
||||||
|
logging.info(
|
||||||
|
f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
|
||||||
|
)
|
||||||
|
|
||||||
|
success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
|
||||||
|
avg_return = np.mean(episode_returns)
|
||||||
|
summary_str = (
|
||||||
|
f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
|
||||||
|
)
|
||||||
|
for r in range(self.env_max_reward + 1):
|
||||||
|
more_or_equal_r = (np.array(highest_rewards) >= r).sum()
|
||||||
|
more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
|
||||||
|
summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"
|
||||||
|
|
||||||
|
logging.info(summary_str)
|
||||||
|
|
||||||
|
result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
|
||||||
|
with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
|
||||||
|
f.write(summary_str)
|
||||||
|
f.write(repr(episode_returns))
|
||||||
|
f.write("\n\n")
|
||||||
|
f.write(repr(highest_rewards))
|
||||||
|
|
||||||
|
return success_rate, avg_return
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAloha:
|
||||||
|
def __init__(self, aloha_config):
|
||||||
|
"""
|
||||||
|
初始化设备
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_name (str): 设备名称
|
||||||
|
"""
|
||||||
|
config_left_arm = aloha_config["rm_left_arm"]
|
||||||
|
config_right_arm = aloha_config["rm_right_arm"]
|
||||||
|
config_head_camera = aloha_config["head_camera"]
|
||||||
|
config_bottom_camera = aloha_config["bottom_camera"]
|
||||||
|
config_left_camera = aloha_config["left_camera"]
|
||||||
|
config_right_camera = aloha_config["right_camera"]
|
||||||
|
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
|
||||||
|
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
|
||||||
|
self.arm_axis = aloha_config["arm_axis"]
|
||||||
|
self.arm_left = RmArm(config_left_arm)
|
||||||
|
self.arm_right = RmArm(config_right_arm)
|
||||||
|
self.camera_left = RealSenseCamera(config_head_camera, False)
|
||||||
|
self.camera_right = RealSenseCamera(config_bottom_camera, False)
|
||||||
|
self.camera_bottom = RealSenseCamera(config_left_camera, False)
|
||||||
|
self.camera_top = RealSenseCamera(config_right_camera, False)
|
||||||
|
self.camera_left.start_camera()
|
||||||
|
self.camera_right.start_camera()
|
||||||
|
self.camera_bottom.start_camera()
|
||||||
|
self.camera_top.start_camera()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
关闭摄像头
|
||||||
|
"""
|
||||||
|
self.camera_left.close()
|
||||||
|
self.camera_right.close()
|
||||||
|
self.camera_bottom.close()
|
||||||
|
self.camera_top.close()
|
||||||
|
|
||||||
|
def get_qps(self):
|
||||||
|
"""
|
||||||
|
获取关节角度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 关节角度
|
||||||
|
"""
|
||||||
|
left_slave_arm_angle = self.arm_left.get_joint_angle()
|
||||||
|
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
|
||||||
|
right_slave_arm_angle = self.arm_right.get_joint_angle()
|
||||||
|
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
|
||||||
|
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
|
||||||
|
|
||||||
|
def get_qvel(self):
|
||||||
|
"""
|
||||||
|
获取关节速度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 关节速度
|
||||||
|
"""
|
||||||
|
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
|
||||||
|
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
|
||||||
|
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
|
||||||
|
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
|
||||||
|
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
|
||||||
|
|
||||||
|
def get_effort(self):
|
||||||
|
"""
|
||||||
|
获取关节力
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 关节力
|
||||||
|
"""
|
||||||
|
left_slave_arm_effort = self.arm_left.get_joint_effort()
|
||||||
|
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
|
||||||
|
right_slave_arm_effort = self.arm_right.get_joint_effort()
|
||||||
|
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
|
||||||
|
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
|
||||||
|
|
||||||
|
def get_images(self):
|
||||||
|
"""
|
||||||
|
获取图像
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 图像字典
|
||||||
|
"""
|
||||||
|
self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
|
||||||
|
self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
|
||||||
|
True, False, False, False
|
||||||
|
)
|
||||||
|
self.left_image, _, _, _ = self.camera_left.read_frame(
|
||||||
|
True, False, False, False
|
||||||
|
)
|
||||||
|
self.right_image, _, _, _ = self.camera_right.read_frame(
|
||||||
|
True, False, False, False
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"cam_high": self.top_image,
|
||||||
|
"cam_low": self.bottom_image,
|
||||||
|
"cam_left": self.left_image,
|
||||||
|
"cam_right": self.right_image,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_observation(self):
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs["qpos"] = self.get_qps()
|
||||||
|
obs["qvel"] = self.get_qvel()
|
||||||
|
obs["effort"] = self.get_effort()
|
||||||
|
obs["images"] = self.get_images()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
logging.info("Resetting the environment")
|
||||||
|
self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
|
||||||
|
self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
|
||||||
|
self.arm_left.set_gripper_position(0)
|
||||||
|
self.arm_right.set_gripper_position(0)
|
||||||
|
return dm_env.TimeStep(
|
||||||
|
step_type=dm_env.StepType.FIRST,
|
||||||
|
reward=0,
|
||||||
|
discount=None,
|
||||||
|
observation=self.get_observation(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def step(self, target_angle):
|
||||||
|
self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
|
||||||
|
self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis+1:self.arm_axis*2+1])
|
||||||
|
self.arm_left.set_gripper_position(target_angle[self.arm_axis])
|
||||||
|
self.arm_right.set_gripper_position(target_angle[(self.arm_axis*2 + 1)])
|
||||||
|
return dm_env.TimeStep(
|
||||||
|
step_type=dm_env.StepType.MID,
|
||||||
|
reward=0,
|
||||||
|
discount=None,
|
||||||
|
observation=self.get_observation(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
|
||||||
|
# config = yaml.safe_load(f)
|
||||||
|
# aloha_config = config["robot_env"]
|
||||||
|
# device = DeviceAloha(aloha_config)
|
||||||
|
# device.reset()
|
||||||
|
# while True:
|
||||||
|
# init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
|
||||||
|
# time_step = time.time()
|
||||||
|
# timestep = device.step(init_angle)
|
||||||
|
# logging.info(f"Time: {time.time() - time_step}")
|
||||||
|
# obs = timestep.observation
|
||||||
|
|
||||||
|
with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
# logging.info(f"Config: {config}")
|
||||||
|
evaluator = RmActEvaluator(config)
|
||||||
|
success_rate, avg_return = evaluator.evaluate()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
@@ -0,0 +1,153 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Backbone modules.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
from torch import nn
|
||||||
|
from typing import Dict, List
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .position_encoding import build_position_encoding
|
||||||
|
from torchvision.models import ResNet18_Weights
|
||||||
|
from torchvision.models._utils import IntermediateLayerGetter
|
||||||
|
from shadow_act.utils.misc import NestedTensor, is_main_process
|
||||||
|
|
||||||
|
|
||||||
|
class FrozenBatchNorm2d(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
||||||
|
|
||||||
|
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
||||||
|
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
|
||||||
|
produce nans.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, n):
|
||||||
|
super(FrozenBatchNorm2d, self).__init__()
|
||||||
|
self.register_buffer("weight", torch.ones(n))
|
||||||
|
self.register_buffer("bias", torch.zeros(n))
|
||||||
|
self.register_buffer("running_mean", torch.zeros(n))
|
||||||
|
self.register_buffer("running_var", torch.ones(n))
|
||||||
|
|
||||||
|
def _load_from_state_dict(
|
||||||
|
self,
|
||||||
|
state_dict,
|
||||||
|
prefix,
|
||||||
|
local_metadata,
|
||||||
|
strict,
|
||||||
|
missing_keys,
|
||||||
|
unexpected_keys,
|
||||||
|
error_msgs,
|
||||||
|
):
|
||||||
|
num_batches_tracked_key = prefix + "num_batches_tracked"
|
||||||
|
if num_batches_tracked_key in state_dict:
|
||||||
|
del state_dict[num_batches_tracked_key]
|
||||||
|
|
||||||
|
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
||||||
|
state_dict,
|
||||||
|
prefix,
|
||||||
|
local_metadata,
|
||||||
|
strict,
|
||||||
|
missing_keys,
|
||||||
|
unexpected_keys,
|
||||||
|
error_msgs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# move reshapes to the beginning
|
||||||
|
# to make it fuser-friendly
|
||||||
|
w = self.weight.reshape(1, -1, 1, 1)
|
||||||
|
b = self.bias.reshape(1, -1, 1, 1)
|
||||||
|
rv = self.running_var.reshape(1, -1, 1, 1)
|
||||||
|
rm = self.running_mean.reshape(1, -1, 1, 1)
|
||||||
|
eps = 1e-5
|
||||||
|
scale = w * (rv + eps).rsqrt()
|
||||||
|
bias = b - rm * scale
|
||||||
|
return x * scale + bias
|
||||||
|
|
||||||
|
|
||||||
|
class BackboneBase(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
backbone: nn.Module,
|
||||||
|
train_backbone: bool,
|
||||||
|
num_channels: int,
|
||||||
|
return_interm_layers: bool,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
|
||||||
|
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
||||||
|
# parameter.requires_grad_(False)
|
||||||
|
if return_interm_layers:
|
||||||
|
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
||||||
|
else:
|
||||||
|
return_layers = {"layer4": "0"}
|
||||||
|
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
||||||
|
self.num_channels = num_channels
|
||||||
|
|
||||||
|
def forward(self, tensor):
|
||||||
|
xs = self.body(tensor)
|
||||||
|
return xs
|
||||||
|
# out: Dict[str, NestedTensor] = {}
|
||||||
|
# for name, x in xs.items():
|
||||||
|
# m = tensor_list.mask
|
||||||
|
# assert m is not None
|
||||||
|
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
||||||
|
# out[name] = NestedTensor(x, mask)
|
||||||
|
# return out
|
||||||
|
|
||||||
|
|
||||||
|
class Backbone(BackboneBase):
|
||||||
|
"""ResNet backbone with frozen BatchNorm."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
train_backbone: bool,
|
||||||
|
return_interm_layers: bool,
|
||||||
|
dilation: bool,
|
||||||
|
):
|
||||||
|
backbone = getattr(torchvision.models, name)(
|
||||||
|
replace_stride_with_dilation=[False, False, dilation],
|
||||||
|
weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None,
|
||||||
|
norm_layer=FrozenBatchNorm2d,
|
||||||
|
)
|
||||||
|
# backbone = getattr(torchvision.models, name)(
|
||||||
|
# replace_stride_with_dilation=[False, False, dilation],
|
||||||
|
# pretrained=is_main_process(),
|
||||||
|
# norm_layer=FrozenBatchNorm2d,
|
||||||
|
# ) # pretrained # TODO do we want frozen batch_norm??
|
||||||
|
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
|
||||||
|
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
|
||||||
|
|
||||||
|
|
||||||
|
class Joiner(nn.Sequential):
|
||||||
|
def __init__(self, backbone, position_embedding):
|
||||||
|
super().__init__(backbone, position_embedding)
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
xs = self[0](tensor_list)
|
||||||
|
out: List[NestedTensor] = []
|
||||||
|
pos = []
|
||||||
|
for name, x in xs.items():
|
||||||
|
out.append(x)
|
||||||
|
# position encoding
|
||||||
|
pos.append(self[1](x).to(x.dtype))
|
||||||
|
|
||||||
|
return out, pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_backbone(
|
||||||
|
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
|
||||||
|
):
|
||||||
|
|
||||||
|
position_embedding = build_position_encoding(
|
||||||
|
hidden_dim=hidden_dim, position_embedding_type=position_embedding_type
|
||||||
|
)
|
||||||
|
train_backbone = lr_backbone > 0
|
||||||
|
return_interm_layers = masks
|
||||||
|
backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
|
||||||
|
model = Joiner(backbone, position_embedding)
|
||||||
|
model.num_channels = backbone.num_channels
|
||||||
|
return model
|
||||||
@@ -0,0 +1,436 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
DETR model and criterion classes.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from shadow_act.models.transformer import Transformer
|
||||||
|
from .backbone import build_backbone
|
||||||
|
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def reparametrize(mu, logvar):
|
||||||
|
std = logvar.div(2).exp()
|
||||||
|
eps = Variable(std.data.new(std.size()).normal_())
|
||||||
|
return mu + std * eps
|
||||||
|
|
||||||
|
|
||||||
|
def get_sinusoid_encoding_table(n_position, d_hid):
|
||||||
|
def get_position_angle_vec(position):
|
||||||
|
return [
|
||||||
|
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
|
||||||
|
for hid_j in range(d_hid)
|
||||||
|
]
|
||||||
|
|
||||||
|
sinusoid_table = np.array(
|
||||||
|
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
|
||||||
|
)
|
||||||
|
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||||
|
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||||
|
|
||||||
|
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class DETRVAE(nn.Module):
|
||||||
|
"""This is the DETR module that performs object detection"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
backbones,
|
||||||
|
transformer,
|
||||||
|
encoder,
|
||||||
|
state_dim,
|
||||||
|
num_queries,
|
||||||
|
camera_names,
|
||||||
|
vq,
|
||||||
|
vq_class,
|
||||||
|
vq_dim,
|
||||||
|
action_dim,
|
||||||
|
):
|
||||||
|
"""Initializes the model.
|
||||||
|
Parameters:
|
||||||
|
backbones: torch module of the backbone to be used. See backbone.py
|
||||||
|
transformer: torch module of the transformer architecture. See transformer.py
|
||||||
|
state_dim: robot state dimension of the environment
|
||||||
|
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||||
|
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||||
|
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_queries = num_queries
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.transformer = transformer
|
||||||
|
self.encoder = encoder
|
||||||
|
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
|
||||||
|
self.state_dim, self.action_dim = state_dim, action_dim
|
||||||
|
hidden_dim = transformer.d_model
|
||||||
|
self.action_head = nn.Linear(hidden_dim, action_dim)
|
||||||
|
self.is_pad_head = nn.Linear(hidden_dim, 1)
|
||||||
|
self.query_embed = nn.Embedding(num_queries, hidden_dim)
|
||||||
|
if backbones is not None:
|
||||||
|
self.input_proj = nn.Conv2d(
|
||||||
|
backbones[0].num_channels, hidden_dim, kernel_size=1
|
||||||
|
)
|
||||||
|
self.backbones = nn.ModuleList(backbones)
|
||||||
|
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||||
|
else:
|
||||||
|
# input_dim = 14 + 7 # robot_state + env_state
|
||||||
|
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
|
||||||
|
self.input_proj_env_state = nn.Linear(7, hidden_dim)
|
||||||
|
self.pos = torch.nn.Embedding(2, hidden_dim)
|
||||||
|
self.backbones = None
|
||||||
|
|
||||||
|
# encoder extra parameters
|
||||||
|
self.latent_dim = 32 # final size of latent z # TODO tune
|
||||||
|
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
|
||||||
|
self.encoder_action_proj = nn.Linear(
|
||||||
|
action_dim, hidden_dim
|
||||||
|
) # project action to embedding
|
||||||
|
self.encoder_joint_proj = nn.Linear(
|
||||||
|
action_dim, hidden_dim
|
||||||
|
) # project qpos to embedding
|
||||||
|
if self.vq:
|
||||||
|
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
|
||||||
|
else:
|
||||||
|
self.latent_proj = nn.Linear(
|
||||||
|
hidden_dim, self.latent_dim * 2
|
||||||
|
) # project hidden state to latent std, var
|
||||||
|
self.register_buffer(
|
||||||
|
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
|
||||||
|
) # [CLS], qpos, a_seq
|
||||||
|
|
||||||
|
# decoder extra parameters
|
||||||
|
if self.vq:
|
||||||
|
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
|
||||||
|
else:
|
||||||
|
self.latent_out_proj = nn.Linear(
|
||||||
|
self.latent_dim, hidden_dim
|
||||||
|
) # project latent sample to embedding
|
||||||
|
self.additional_pos_embed = nn.Embedding(
|
||||||
|
2, hidden_dim
|
||||||
|
) # learned position embedding for proprio and latent
|
||||||
|
|
||||||
|
def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
|
||||||
|
bs, _ = qpos.shape
|
||||||
|
if self.encoder is None:
|
||||||
|
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
|
||||||
|
qpos.device
|
||||||
|
)
|
||||||
|
latent_input = self.latent_out_proj(latent_sample)
|
||||||
|
probs = binaries = mu = logvar = None
|
||||||
|
else:
|
||||||
|
# cvae encoder
|
||||||
|
is_training = actions is not None # train or val
|
||||||
|
### Obtain latent z from action sequence
|
||||||
|
if is_training:
|
||||||
|
# project action sequence to embedding dim, and concat with a CLS token
|
||||||
|
action_embed = self.encoder_action_proj(
|
||||||
|
actions
|
||||||
|
) # (bs, seq, hidden_dim)
|
||||||
|
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
|
||||||
|
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
|
||||||
|
cls_embed = self.cls_embed.weight # (1, hidden_dim)
|
||||||
|
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
|
||||||
|
bs, 1, 1
|
||||||
|
) # (bs, 1, hidden_dim)
|
||||||
|
encoder_input = torch.cat(
|
||||||
|
[cls_embed, qpos_embed, action_embed], axis=1
|
||||||
|
) # (bs, seq+1, hidden_dim)
|
||||||
|
encoder_input = encoder_input.permute(
|
||||||
|
1, 0, 2
|
||||||
|
) # (seq+1, bs, hidden_dim)
|
||||||
|
# do not mask cls token
|
||||||
|
cls_joint_is_pad = torch.full((bs, 2), False).to(
|
||||||
|
qpos.device
|
||||||
|
) # False: not a padding
|
||||||
|
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
|
||||||
|
# obtain position embedding
|
||||||
|
pos_embed = self.pos_table.clone().detach()
|
||||||
|
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
|
||||||
|
# query model
|
||||||
|
encoder_output = self.encoder(
|
||||||
|
encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
|
||||||
|
)
|
||||||
|
encoder_output = encoder_output[0] # take cls output only
|
||||||
|
latent_info = self.latent_proj(encoder_output)
|
||||||
|
|
||||||
|
if self.vq:
|
||||||
|
logits = latent_info.reshape(
|
||||||
|
[*latent_info.shape[:-1], self.vq_class, self.vq_dim]
|
||||||
|
)
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
binaries = (
|
||||||
|
F.one_hot(
|
||||||
|
torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(
|
||||||
|
-1
|
||||||
|
),
|
||||||
|
self.vq_dim,
|
||||||
|
)
|
||||||
|
.view(-1, self.vq_class, self.vq_dim)
|
||||||
|
.float()
|
||||||
|
)
|
||||||
|
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
|
||||||
|
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
|
||||||
|
straigt_through = binaries_flat - probs_flat.detach() + probs_flat
|
||||||
|
latent_input = self.latent_out_proj(straigt_through)
|
||||||
|
mu = logvar = None
|
||||||
|
else:
|
||||||
|
probs = binaries = None
|
||||||
|
mu = latent_info[:, : self.latent_dim]
|
||||||
|
logvar = latent_info[:, self.latent_dim :]
|
||||||
|
latent_sample = reparametrize(mu, logvar)
|
||||||
|
latent_input = self.latent_out_proj(latent_sample)
|
||||||
|
|
||||||
|
else:
|
||||||
|
mu = logvar = binaries = probs = None
|
||||||
|
if self.vq:
|
||||||
|
latent_input = self.latent_out_proj(
|
||||||
|
vq_sample.view(-1, self.vq_class * self.vq_dim)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
latent_sample = torch.zeros(
|
||||||
|
[bs, self.latent_dim], dtype=torch.float32
|
||||||
|
).to(qpos.device)
|
||||||
|
latent_input = self.latent_out_proj(latent_sample)
|
||||||
|
|
||||||
|
return latent_input, probs, binaries, mu, logvar
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
qpos: batch, qpos_dim
|
||||||
|
image: batch, num_cam, channel, height, width
|
||||||
|
env_state: None
|
||||||
|
actions: batch, seq, action_dim
|
||||||
|
"""
|
||||||
|
|
||||||
|
latent_input, probs, binaries, mu, logvar = self.encode(
|
||||||
|
qpos, actions, is_pad, vq_sample
|
||||||
|
)
|
||||||
|
|
||||||
|
# cvae decoder
|
||||||
|
if self.backbones is not None:
|
||||||
|
# Image observation features and position embeddings
|
||||||
|
all_cam_features = []
|
||||||
|
all_cam_pos = []
|
||||||
|
for cam_id, cam_name in enumerate(self.camera_names):
|
||||||
|
# TODO: fix this error
|
||||||
|
features, pos = self.backbones[0](image[:, cam_id])
|
||||||
|
features = features[0] # take the last layer feature
|
||||||
|
pos = pos[0]
|
||||||
|
all_cam_features.append(self.input_proj(features))
|
||||||
|
all_cam_pos.append(pos)
|
||||||
|
# proprioception features
|
||||||
|
proprio_input = self.input_proj_robot_state(qpos)
|
||||||
|
# fold camera dimension into width dimension
|
||||||
|
src = torch.cat(all_cam_features, axis=3)
|
||||||
|
pos = torch.cat(all_cam_pos, axis=3)
|
||||||
|
hs = self.transformer(
|
||||||
|
src,
|
||||||
|
None,
|
||||||
|
self.query_embed.weight,
|
||||||
|
pos,
|
||||||
|
latent_input,
|
||||||
|
proprio_input,
|
||||||
|
self.additional_pos_embed.weight,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
qpos = self.input_proj_robot_state(qpos)
|
||||||
|
env_state = self.input_proj_env_state(env_state)
|
||||||
|
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
|
||||||
|
hs = self.transformer(
|
||||||
|
transformer_input, None, self.query_embed.weight, self.pos.weight
|
||||||
|
)[0]
|
||||||
|
a_hat = self.action_head(hs)
|
||||||
|
is_pad_hat = self.is_pad_head(hs)
|
||||||
|
return a_hat, is_pad_hat, [mu, logvar], probs, binaries
|
||||||
|
|
||||||
|
|
||||||
|
class CNNMLP(nn.Module):
|
||||||
|
def __init__(self, backbones, state_dim, camera_names):
|
||||||
|
"""Initializes the model.
|
||||||
|
Parameters:
|
||||||
|
backbones: torch module of the backbone to be used. See backbone.py
|
||||||
|
transformer: torch module of the transformer architecture. See transformer.py
|
||||||
|
state_dim: robot state dimension of the environment
|
||||||
|
num_queries: number of object queries, ie detection slot. This is the maximal number of objects
|
||||||
|
DETR can detect in a single image. For COCO, we recommend 100 queries.
|
||||||
|
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.action_head = nn.Linear(1000, state_dim) # TODO add more
|
||||||
|
if backbones is not None:
|
||||||
|
self.backbones = nn.ModuleList(backbones)
|
||||||
|
backbone_down_projs = []
|
||||||
|
for backbone in backbones:
|
||||||
|
down_proj = nn.Sequential(
|
||||||
|
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
|
||||||
|
nn.Conv2d(128, 64, kernel_size=5),
|
||||||
|
nn.Conv2d(64, 32, kernel_size=5),
|
||||||
|
)
|
||||||
|
backbone_down_projs.append(down_proj)
|
||||||
|
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
|
||||||
|
|
||||||
|
mlp_in_dim = 768 * len(backbones) + state_dim
|
||||||
|
self.mlp = mlp(
|
||||||
|
input_dim=mlp_in_dim,
|
||||||
|
hidden_dim=1024,
|
||||||
|
output_dim=self.action_dim,
|
||||||
|
hidden_depth=2,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, qpos, image, env_state, actions=None):
|
||||||
|
"""
|
||||||
|
qpos: batch, qpos_dim
|
||||||
|
image: batch, num_cam, channel, height, width
|
||||||
|
env_state: None
|
||||||
|
actions: batch, seq, action_dim
|
||||||
|
"""
|
||||||
|
is_training = actions is not None # train or val
|
||||||
|
bs, _ = qpos.shape
|
||||||
|
# Image observation features and position embeddings
|
||||||
|
all_cam_features = []
|
||||||
|
for cam_id, cam_name in enumerate(self.camera_names):
|
||||||
|
features, pos = self.backbones[cam_id](image[:, cam_id])
|
||||||
|
features = features[0] # take the last layer feature
|
||||||
|
pos = pos[0] # not used
|
||||||
|
all_cam_features.append(self.backbone_down_projs[cam_id](features))
|
||||||
|
# flatten everything
|
||||||
|
flattened_features = []
|
||||||
|
for cam_feature in all_cam_features:
|
||||||
|
flattened_features.append(cam_feature.reshape([bs, -1]))
|
||||||
|
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
|
||||||
|
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
|
||||||
|
a_hat = self.mlp(features)
|
||||||
|
return a_hat
|
||||||
|
|
||||||
|
|
||||||
|
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
|
||||||
|
if hidden_depth == 0:
|
||||||
|
mods = [nn.Linear(input_dim, output_dim)]
|
||||||
|
else:
|
||||||
|
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||||
|
for i in range(hidden_depth - 1):
|
||||||
|
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
|
||||||
|
mods.append(nn.Linear(hidden_dim, output_dim))
|
||||||
|
trunk = nn.Sequential(*mods)
|
||||||
|
return trunk
|
||||||
|
|
||||||
|
|
||||||
|
def build_encoder(
|
||||||
|
hidden_dim, # 256
|
||||||
|
dropout, # 0.1
|
||||||
|
nheads, # 8
|
||||||
|
dim_feedforward,
|
||||||
|
num_encoder_layers, # 4 # TODO shared with VAE decoder
|
||||||
|
normalize_before, # False
|
||||||
|
):
|
||||||
|
activation = "relu"
|
||||||
|
|
||||||
|
encoder_layer = TransformerEncoderLayer(
|
||||||
|
hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
|
||||||
|
)
|
||||||
|
encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
|
||||||
|
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
|
||||||
|
|
||||||
|
return encoder
|
||||||
|
|
||||||
|
|
||||||
|
def build_vae(
|
||||||
|
hidden_dim,
|
||||||
|
state_dim,
|
||||||
|
position_embedding_type,
|
||||||
|
lr_backbone,
|
||||||
|
masks,
|
||||||
|
backbone,
|
||||||
|
dilation,
|
||||||
|
dropout,
|
||||||
|
nheads,
|
||||||
|
dim_feedforward,
|
||||||
|
enc_layers,
|
||||||
|
dec_layers,
|
||||||
|
pre_norm,
|
||||||
|
num_queries,
|
||||||
|
camera_names,
|
||||||
|
vq,
|
||||||
|
vq_class,
|
||||||
|
vq_dim,
|
||||||
|
action_dim,
|
||||||
|
no_encoder,
|
||||||
|
):
|
||||||
|
# TODO hardcode
|
||||||
|
|
||||||
|
# From state
|
||||||
|
# backbone = None # from state for now, no need for conv nets
|
||||||
|
# From image
|
||||||
|
backbones = []
|
||||||
|
backbone = build_backbone(
|
||||||
|
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
|
||||||
|
)
|
||||||
|
backbones.append(backbone)
|
||||||
|
|
||||||
|
transformer = build_transformer(
|
||||||
|
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
if no_encoder:
|
||||||
|
encoder = None
|
||||||
|
else:
|
||||||
|
encoder = build_encoder(
|
||||||
|
hidden_dim,
|
||||||
|
dropout,
|
||||||
|
nheads,
|
||||||
|
dim_feedforward,
|
||||||
|
enc_layers,
|
||||||
|
pre_norm,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = DETRVAE(
|
||||||
|
backbones,
|
||||||
|
transformer,
|
||||||
|
encoder,
|
||||||
|
state_dim,
|
||||||
|
num_queries,
|
||||||
|
camera_names,
|
||||||
|
vq,
|
||||||
|
vq_class,
|
||||||
|
vq_dim,
|
||||||
|
action_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
# TODO
|
||||||
|
def build_cnnmlp(args):
|
||||||
|
state_dim = 14 # TODO hardcode
|
||||||
|
|
||||||
|
# From state
|
||||||
|
# backbone = None # from state for now, no need for conv nets
|
||||||
|
# From image
|
||||||
|
backbones = []
|
||||||
|
for _ in args.camera_names:
|
||||||
|
backbone = build_backbone(args)
|
||||||
|
backbones.append(backbone)
|
||||||
|
|
||||||
|
model = CNNMLP(
|
||||||
|
backbones,
|
||||||
|
state_dim=state_dim,
|
||||||
|
camera_names=args.camera_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
|
||||||
|
|
||||||
|
return model
|
||||||
@@ -0,0 +1,122 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import torch
|
||||||
|
|
||||||
|
DROPOUT_RATE = 0.1 # 定义 dropout 率
|
||||||
|
|
||||||
|
# 定义一个因果变压器块
|
||||||
|
class Causal_Transformer_Block(nn.Module):
|
||||||
|
def __init__(self, seq_len, latent_dim, num_head) -> None:
|
||||||
|
"""
|
||||||
|
初始化因果变压器块
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_len (int): 序列长度
|
||||||
|
latent_dim (int): 潜在维度
|
||||||
|
num_head (int): 注意力头的数量
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.num_head = num_head
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.ln_1 = nn.LayerNorm(latent_dim) # 层归一化
|
||||||
|
self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True) # 多头注意力机制
|
||||||
|
self.ln_2 = nn.LayerNorm(latent_dim) # 层归一化
|
||||||
|
self.mlp = nn.Sequential(
|
||||||
|
nn.Linear(latent_dim, 4 * latent_dim), # 全连接层
|
||||||
|
nn.GELU(), # GELU 激活函数
|
||||||
|
nn.Linear(4 * latent_dim, latent_dim), # 全连接层
|
||||||
|
nn.Dropout(DROPOUT_RATE), # Dropout
|
||||||
|
)
|
||||||
|
|
||||||
|
# self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()) # 注册注意力掩码
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
前向传播
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): 输入张量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: 输出张量
|
||||||
|
"""
|
||||||
|
# 创建上三角掩码,防止信息泄露
|
||||||
|
attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
|
||||||
|
x = self.ln_1(x) # 层归一化
|
||||||
|
x = x + self.attn(x, x, x, attn_mask=attn_mask)[0] # 加上注意力输出
|
||||||
|
x = self.ln_2(x) # 层归一化
|
||||||
|
x = x + self.mlp(x) # 加上 MLP 输出
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
# 使用自注意力机制而不是 RNN 来建模潜在空间序列
|
||||||
|
class Latent_Model_Transformer(nn.Module):
|
||||||
|
def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
|
||||||
|
"""
|
||||||
|
初始化潜在模型变压器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_dim (int): 输入维度
|
||||||
|
output_dim (int): 输出维度
|
||||||
|
seq_len (int): 序列长度
|
||||||
|
latent_dim (int, optional): 潜在维度,默认值为 256
|
||||||
|
num_head (int, optional): 注意力头的数量,默认值为 8
|
||||||
|
num_layer (int, optional): 变压器层的数量,默认值为 3
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.seq_len = seq_len
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.num_head = num_head
|
||||||
|
self.num_layer = num_layer
|
||||||
|
self.input_layer = nn.Linear(input_dim, latent_dim) # 输入层
|
||||||
|
self.weight_pos_embed = nn.Embedding(seq_len, latent_dim) # 位置嵌入
|
||||||
|
self.attention_blocks = nn.Sequential(
|
||||||
|
nn.Dropout(DROPOUT_RATE), # Dropout
|
||||||
|
*[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)], # 多个因果变压器块
|
||||||
|
nn.LayerNorm(latent_dim) # 层归一化
|
||||||
|
)
|
||||||
|
self.output_layer = nn.Linear(latent_dim, output_dim) # 输出层
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
前向传播
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): 输入张量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: 输出张量
|
||||||
|
"""
|
||||||
|
x = self.input_layer(x) # 输入层
|
||||||
|
x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device)) # 加上位置嵌入
|
||||||
|
x = self.attention_blocks(x) # 通过注意力块
|
||||||
|
logits = self.output_layer(x) # 输出层
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def generate(self, n, temperature=0.1, x=None):
|
||||||
|
"""
|
||||||
|
生成序列
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n (int): 生成序列的数量
|
||||||
|
temperature (float, optional): 采样温度,默认值为 0.1
|
||||||
|
x (torch.Tensor, optional): 初始输入张量,默认值为 None
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: 生成的序列
|
||||||
|
"""
|
||||||
|
if x is None:
|
||||||
|
x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device) # 初始化输入
|
||||||
|
for i in range(self.seq_len):
|
||||||
|
logits = self.forward(x)[:, -1] # 获取最后一个时间步的输出
|
||||||
|
probs = torch.softmax(logits / temperature, dim=-1) # 计算概率分布
|
||||||
|
samples = torch.multinomial(probs, num_samples=1)[..., 0] # 从概率分布中采样
|
||||||
|
samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float() # 转为 one-hot 编码
|
||||||
|
x = torch.cat([x, samples_one_hot[:, None, :]], dim=1) # 将新采样的结果添加到输入中
|
||||||
|
|
||||||
|
return x[:, 1:, :] # 返回生成的序列(去掉初始的零输入)
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Various positional encodings for the transformer.
|
||||||
|
"""
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from shadow_act.utils.misc import NestedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingSine(nn.Module):
|
||||||
|
"""
|
||||||
|
This is a more standard version of the position embedding, very similar to the one
|
||||||
|
used by the Attention is all you need paper, generalized to work on images.
|
||||||
|
"""
|
||||||
|
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||||
|
super().__init__()
|
||||||
|
self.num_pos_feats = num_pos_feats
|
||||||
|
self.temperature = temperature
|
||||||
|
self.normalize = normalize
|
||||||
|
if scale is not None and normalize is False:
|
||||||
|
raise ValueError("normalize should be True if scale is passed")
|
||||||
|
if scale is None:
|
||||||
|
scale = 2 * math.pi
|
||||||
|
self.scale = scale
|
||||||
|
|
||||||
|
def forward(self, tensor):
|
||||||
|
x = tensor
|
||||||
|
# mask = tensor_list.mask
|
||||||
|
# assert mask is not None
|
||||||
|
# not_mask = ~mask
|
||||||
|
|
||||||
|
not_mask = torch.ones_like(x[0, [0]])
|
||||||
|
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||||
|
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||||
|
if self.normalize:
|
||||||
|
eps = 1e-6
|
||||||
|
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||||
|
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||||
|
|
||||||
|
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||||
|
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||||
|
|
||||||
|
pos_x = x_embed[:, :, :, None] / dim_t
|
||||||
|
pos_y = y_embed[:, :, :, None] / dim_t
|
||||||
|
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||||
|
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
||||||
|
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEmbeddingLearned(nn.Module):
|
||||||
|
"""
|
||||||
|
Absolute pos embedding, learned.
|
||||||
|
"""
|
||||||
|
def __init__(self, num_pos_feats=256):
|
||||||
|
super().__init__()
|
||||||
|
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||||
|
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
nn.init.uniform_(self.row_embed.weight)
|
||||||
|
nn.init.uniform_(self.col_embed.weight)
|
||||||
|
|
||||||
|
def forward(self, tensor_list: NestedTensor):
|
||||||
|
x = tensor_list.tensors
|
||||||
|
h, w = x.shape[-2:]
|
||||||
|
i = torch.arange(w, device=x.device)
|
||||||
|
j = torch.arange(h, device=x.device)
|
||||||
|
x_emb = self.col_embed(i)
|
||||||
|
y_emb = self.row_embed(j)
|
||||||
|
pos = torch.cat([
|
||||||
|
x_emb.unsqueeze(0).repeat(h, 1, 1),
|
||||||
|
y_emb.unsqueeze(1).repeat(1, w, 1),
|
||||||
|
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_position_encoding(hidden_dim, position_embedding_type):
|
||||||
|
N_steps = hidden_dim // 2
|
||||||
|
if position_embedding_type in ('v2', 'sine'):
|
||||||
|
# TODO find a better way of exposing other arguments
|
||||||
|
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||||
|
elif position_embedding_type in ('v3', 'learned'):
|
||||||
|
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"not supported {position_embedding_type}")
|
||||||
|
|
||||||
|
return position_embedding
|
||||||
@@ -0,0 +1,424 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
DETR Transformer class.
|
||||||
|
|
||||||
|
Copy-paste from torch.nn.Transformer with modifications:
|
||||||
|
* positional encodings are passed in MHattention
|
||||||
|
* extra LN at the end of encoder is removed
|
||||||
|
* decoder returns a stack of activations from all decoding layers
|
||||||
|
"""
|
||||||
|
import copy
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import nn, Tensor
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model=512,
|
||||||
|
nhead=8,
|
||||||
|
num_encoder_layers=6,
|
||||||
|
num_decoder_layers=6,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
normalize_before=False,
|
||||||
|
return_intermediate_dec=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
encoder_layer = TransformerEncoderLayer(
|
||||||
|
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||||
|
)
|
||||||
|
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
||||||
|
self.encoder = TransformerEncoder(
|
||||||
|
encoder_layer, num_encoder_layers, encoder_norm
|
||||||
|
)
|
||||||
|
|
||||||
|
decoder_layer = TransformerDecoderLayer(
|
||||||
|
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
|
||||||
|
)
|
||||||
|
decoder_norm = nn.LayerNorm(d_model)
|
||||||
|
self.decoder = TransformerDecoder(
|
||||||
|
decoder_layer,
|
||||||
|
num_decoder_layers,
|
||||||
|
decoder_norm,
|
||||||
|
return_intermediate=return_intermediate_dec,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
self.d_model = d_model
|
||||||
|
self.nhead = nhead
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
for p in self.parameters():
|
||||||
|
if p.dim() > 1:
|
||||||
|
nn.init.xavier_uniform_(p)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src,
|
||||||
|
mask,
|
||||||
|
query_embed,
|
||||||
|
pos_embed,
|
||||||
|
latent_input=None,
|
||||||
|
proprio_input=None,
|
||||||
|
additional_pos_embed=None,
|
||||||
|
):
|
||||||
|
# TODO flatten only when input has H and W
|
||||||
|
if len(src.shape) == 4: # has H and W
|
||||||
|
# flatten NxCxHxW to HWxNxC
|
||||||
|
bs, c, h, w = src.shape
|
||||||
|
src = src.flatten(2).permute(2, 0, 1)
|
||||||
|
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
|
||||||
|
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||||
|
# mask = mask.flatten(1)
|
||||||
|
|
||||||
|
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
|
||||||
|
1, bs, 1
|
||||||
|
) # seq, bs, dim
|
||||||
|
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
|
||||||
|
|
||||||
|
addition_input = torch.stack([latent_input, proprio_input], axis=0)
|
||||||
|
src = torch.cat([addition_input, src], axis=0)
|
||||||
|
else:
|
||||||
|
assert len(src.shape) == 3
|
||||||
|
# flatten NxHWxC to HWxNxC
|
||||||
|
bs, hw, c = src.shape
|
||||||
|
src = src.permute(1, 0, 2)
|
||||||
|
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||||
|
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
|
||||||
|
|
||||||
|
tgt = torch.zeros_like(query_embed)
|
||||||
|
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
||||||
|
hs = self.decoder(
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
memory_key_padding_mask=mask,
|
||||||
|
pos=pos_embed,
|
||||||
|
query_pos=query_embed,
|
||||||
|
)
|
||||||
|
hs = hs.transpose(1, 2)
|
||||||
|
return hs
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = _get_clones(encoder_layer, num_layers)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.norm = norm
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src,
|
||||||
|
mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
output = src
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
output = layer(
|
||||||
|
output,
|
||||||
|
src_mask=mask,
|
||||||
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
|
pos=pos,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoder(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
||||||
|
super().__init__()
|
||||||
|
self.layers = _get_clones(decoder_layer, num_layers)
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.norm = norm
|
||||||
|
self.return_intermediate = return_intermediate
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask: Optional[Tensor] = None,
|
||||||
|
memory_mask: Optional[Tensor] = None,
|
||||||
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
memory_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
query_pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
output = tgt
|
||||||
|
|
||||||
|
intermediate = []
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
output = layer(
|
||||||
|
output,
|
||||||
|
memory,
|
||||||
|
tgt_mask=tgt_mask,
|
||||||
|
memory_mask=memory_mask,
|
||||||
|
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
pos=pos,
|
||||||
|
query_pos=query_pos,
|
||||||
|
)
|
||||||
|
if self.return_intermediate:
|
||||||
|
intermediate.append(self.norm(output))
|
||||||
|
|
||||||
|
if self.norm is not None:
|
||||||
|
output = self.norm(output)
|
||||||
|
if self.return_intermediate:
|
||||||
|
intermediate.pop()
|
||||||
|
intermediate.append(output)
|
||||||
|
|
||||||
|
if self.return_intermediate:
|
||||||
|
return torch.stack(intermediate)
|
||||||
|
|
||||||
|
return output.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerEncoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
normalize_before=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||||
|
# Implementation of Feedforward model
|
||||||
|
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.activation = _get_activation_fn(activation)
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
|
||||||
|
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||||
|
return tensor if pos is None else tensor + pos
|
||||||
|
|
||||||
|
def forward_post(
|
||||||
|
self,
|
||||||
|
src,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
q = k = self.with_pos_embed(src, pos)
|
||||||
|
src2 = self.self_attn(
|
||||||
|
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||||
|
)[0]
|
||||||
|
src = src + self.dropout1(src2)
|
||||||
|
src = self.norm1(src)
|
||||||
|
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||||
|
src = src + self.dropout2(src2)
|
||||||
|
src = self.norm2(src)
|
||||||
|
return src
|
||||||
|
|
||||||
|
def forward_pre(
|
||||||
|
self,
|
||||||
|
src,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
src2 = self.norm1(src)
|
||||||
|
q = k = self.with_pos_embed(src2, pos)
|
||||||
|
src2 = self.self_attn(
|
||||||
|
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
|
||||||
|
)[0]
|
||||||
|
src = src + self.dropout1(src2)
|
||||||
|
src2 = self.norm2(src)
|
||||||
|
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
||||||
|
src = src + self.dropout2(src2)
|
||||||
|
return src
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
src,
|
||||||
|
src_mask: Optional[Tensor] = None,
|
||||||
|
src_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
if self.normalize_before:
|
||||||
|
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
||||||
|
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
||||||
|
|
||||||
|
|
||||||
|
class TransformerDecoderLayer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model,
|
||||||
|
nhead,
|
||||||
|
dim_feedforward=2048,
|
||||||
|
dropout=0.1,
|
||||||
|
activation="relu",
|
||||||
|
normalize_before=False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||||
|
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
||||||
|
# Implementation of Feedforward model
|
||||||
|
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||||
|
|
||||||
|
self.norm1 = nn.LayerNorm(d_model)
|
||||||
|
self.norm2 = nn.LayerNorm(d_model)
|
||||||
|
self.norm3 = nn.LayerNorm(d_model)
|
||||||
|
self.dropout1 = nn.Dropout(dropout)
|
||||||
|
self.dropout2 = nn.Dropout(dropout)
|
||||||
|
self.dropout3 = nn.Dropout(dropout)
|
||||||
|
|
||||||
|
self.activation = _get_activation_fn(activation)
|
||||||
|
self.normalize_before = normalize_before
|
||||||
|
|
||||||
|
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||||
|
return tensor if pos is None else tensor + pos
|
||||||
|
|
||||||
|
def forward_post(
|
||||||
|
self,
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask: Optional[Tensor] = None,
|
||||||
|
memory_mask: Optional[Tensor] = None,
|
||||||
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
memory_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
query_pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
q = k = self.with_pos_embed(tgt, query_pos)
|
||||||
|
tgt2 = self.self_attn(
|
||||||
|
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||||
|
)[0]
|
||||||
|
tgt = tgt + self.dropout1(tgt2)
|
||||||
|
tgt = self.norm1(tgt)
|
||||||
|
tgt2 = self.multihead_attn(
|
||||||
|
query=self.with_pos_embed(tgt, query_pos),
|
||||||
|
key=self.with_pos_embed(memory, pos),
|
||||||
|
value=memory,
|
||||||
|
attn_mask=memory_mask,
|
||||||
|
key_padding_mask=memory_key_padding_mask,
|
||||||
|
)[0]
|
||||||
|
tgt = tgt + self.dropout2(tgt2)
|
||||||
|
tgt = self.norm2(tgt)
|
||||||
|
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||||
|
tgt = tgt + self.dropout3(tgt2)
|
||||||
|
tgt = self.norm3(tgt)
|
||||||
|
return tgt
|
||||||
|
|
||||||
|
def forward_pre(
|
||||||
|
self,
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask: Optional[Tensor] = None,
|
||||||
|
memory_mask: Optional[Tensor] = None,
|
||||||
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
memory_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
query_pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
tgt2 = self.norm1(tgt)
|
||||||
|
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||||
|
tgt2 = self.self_attn(
|
||||||
|
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
|
||||||
|
)[0]
|
||||||
|
tgt = tgt + self.dropout1(tgt2)
|
||||||
|
tgt2 = self.norm2(tgt)
|
||||||
|
tgt2 = self.multihead_attn(
|
||||||
|
query=self.with_pos_embed(tgt2, query_pos),
|
||||||
|
key=self.with_pos_embed(memory, pos),
|
||||||
|
value=memory,
|
||||||
|
attn_mask=memory_mask,
|
||||||
|
key_padding_mask=memory_key_padding_mask,
|
||||||
|
)[0]
|
||||||
|
tgt = tgt + self.dropout2(tgt2)
|
||||||
|
tgt2 = self.norm3(tgt)
|
||||||
|
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||||
|
tgt = tgt + self.dropout3(tgt2)
|
||||||
|
return tgt
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask: Optional[Tensor] = None,
|
||||||
|
memory_mask: Optional[Tensor] = None,
|
||||||
|
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
memory_key_padding_mask: Optional[Tensor] = None,
|
||||||
|
pos: Optional[Tensor] = None,
|
||||||
|
query_pos: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
if self.normalize_before:
|
||||||
|
return self.forward_pre(
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask,
|
||||||
|
memory_mask,
|
||||||
|
tgt_key_padding_mask,
|
||||||
|
memory_key_padding_mask,
|
||||||
|
pos,
|
||||||
|
query_pos,
|
||||||
|
)
|
||||||
|
return self.forward_post(
|
||||||
|
tgt,
|
||||||
|
memory,
|
||||||
|
tgt_mask,
|
||||||
|
memory_mask,
|
||||||
|
tgt_key_padding_mask,
|
||||||
|
memory_key_padding_mask,
|
||||||
|
pos,
|
||||||
|
query_pos,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clones(module, N):
|
||||||
|
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
||||||
|
|
||||||
|
|
||||||
|
def build_transformer(
|
||||||
|
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
|
||||||
|
):
|
||||||
|
return Transformer(
|
||||||
|
d_model=hidden_dim,
|
||||||
|
dropout=dropout,
|
||||||
|
nhead=nheads,
|
||||||
|
dim_feedforward=dim_feedforward,
|
||||||
|
num_encoder_layers=enc_layers,
|
||||||
|
num_decoder_layers=dec_layers,
|
||||||
|
normalize_before=pre_norm,
|
||||||
|
return_intermediate_dec=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_activation_fn(activation):
|
||||||
|
"""Return an activation function given a string"""
|
||||||
|
if activation == "relu":
|
||||||
|
return F.relu
|
||||||
|
if activation == "gelu":
|
||||||
|
return F.gelu
|
||||||
|
if activation == "glu":
|
||||||
|
return F.glu
|
||||||
|
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
@@ -0,0 +1,522 @@
|
|||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import functional as F
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from shadow_act.models.detr_vae import build_vae, build_cnnmlp
|
||||||
|
|
||||||
|
# from diffusers.training_utils import EMAModel
|
||||||
|
# from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
|
||||||
|
# from robomimic.algo.diffusion_policy import replace_bn_with_gn, ConditionalUnet1D
|
||||||
|
|
||||||
|
# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
|
||||||
|
# from diffusers.schedulers.scheduling_ddim import DDIMScheduler
|
||||||
|
|
||||||
|
# 配置日志记录
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# TODO: 重构DiffusionPolicy类
|
||||||
|
class DiffusionPolicy(nn.Module):
|
||||||
|
def __init__(self, args_override):
|
||||||
|
"""
|
||||||
|
初始化DiffusionPolicy类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args_override (dict): 参数覆盖字典
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.camera_names = args_override["camera_names"]
|
||||||
|
self.observation_horizon = args_override["observation_horizon"]
|
||||||
|
self.action_horizon = args_override["action_horizon"]
|
||||||
|
self.prediction_horizon = args_override["prediction_horizon"]
|
||||||
|
self.num_inference_timesteps = args_override["num_inference_timesteps"]
|
||||||
|
self.ema_power = args_override["ema_power"]
|
||||||
|
self.lr = args_override["lr"]
|
||||||
|
self.weight_decay = 0
|
||||||
|
|
||||||
|
self.num_kp = 32
|
||||||
|
self.feature_dimension = 64
|
||||||
|
self.ac_dim = args_override["action_dim"]
|
||||||
|
self.obs_dim = self.feature_dimension * len(self.camera_names) + 14
|
||||||
|
|
||||||
|
backbones = []
|
||||||
|
pools = []
|
||||||
|
linears = []
|
||||||
|
for _ in self.camera_names:
|
||||||
|
backbones.append(
|
||||||
|
ResNet18Conv(input_channel=3, pretrained=False, input_coord_conv=False)
|
||||||
|
)
|
||||||
|
pools.append(
|
||||||
|
SpatialSoftmax(
|
||||||
|
input_shape=[512, 15, 20],
|
||||||
|
num_kp=self.num_kp,
|
||||||
|
temperature=1.0,
|
||||||
|
learnable_temperature=False,
|
||||||
|
noise_std=0.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
linears.append(
|
||||||
|
torch.nn.Linear(int(np.prod([self.num_kp, 2])), self.feature_dimension)
|
||||||
|
)
|
||||||
|
backbones = nn.ModuleList(backbones)
|
||||||
|
pools = nn.ModuleList(pools)
|
||||||
|
linears = nn.ModuleList(linears)
|
||||||
|
|
||||||
|
backbones = replace_bn_with_gn(backbones)
|
||||||
|
|
||||||
|
noise_pred_net = ConditionalUnet1D(
|
||||||
|
input_dim=self.ac_dim,
|
||||||
|
global_cond_dim=self.obs_dim * self.observation_horizon,
|
||||||
|
)
|
||||||
|
|
||||||
|
nets = nn.ModuleDict(
|
||||||
|
{
|
||||||
|
"policy": nn.ModuleDict(
|
||||||
|
{
|
||||||
|
"backbones": backbones,
|
||||||
|
"pools": pools,
|
||||||
|
"linears": linears,
|
||||||
|
"noise_pred_net": noise_pred_net,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
nets = nets.float().cuda()
|
||||||
|
ENABLE_EMA = True
|
||||||
|
if ENABLE_EMA:
|
||||||
|
ema = EMAModel(model=nets, power=self.ema_power)
|
||||||
|
else:
|
||||||
|
ema = None
|
||||||
|
self.nets = nets
|
||||||
|
self.ema = ema
|
||||||
|
|
||||||
|
# 设置噪声调度器
|
||||||
|
self.noise_scheduler = DDIMScheduler(
|
||||||
|
num_train_timesteps=50,
|
||||||
|
beta_schedule="squaredcos_cap_v2",
|
||||||
|
clip_sample=True,
|
||||||
|
set_alpha_to_one=True,
|
||||||
|
steps_offset=0,
|
||||||
|
prediction_type="epsilon",
|
||||||
|
)
|
||||||
|
|
||||||
|
n_parameters = sum(p.numel() for p in self.parameters())
|
||||||
|
logger.info("number of parameters: %.2fM", n_parameters / 1e6)
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
"""
|
||||||
|
配置优化器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
optimizer: 配置的优化器
|
||||||
|
"""
|
||||||
|
optimizer = torch.optim.AdamW(
|
||||||
|
self.nets.parameters(), lr=self.lr, weight_decay=self.weight_decay
|
||||||
|
)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||||
|
"""
|
||||||
|
前向传播函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
qpos (torch.Tensor): 位置张量
|
||||||
|
image (torch.Tensor): 图像张量
|
||||||
|
actions (torch.Tensor, optional): 动作张量
|
||||||
|
is_pad (torch.Tensor, optional): 填充张量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 损失字典(训练时)
|
||||||
|
torch.Tensor: 动作张量(推理时)
|
||||||
|
"""
|
||||||
|
B = qpos.shape[0]
|
||||||
|
if actions is not None: # 训练时
|
||||||
|
nets = self.nets
|
||||||
|
all_features = []
|
||||||
|
for cam_id in range(len(self.camera_names)):
|
||||||
|
cam_image = image[:, cam_id]
|
||||||
|
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||||
|
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||||
|
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||||
|
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||||
|
all_features.append(out_features)
|
||||||
|
|
||||||
|
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||||
|
|
||||||
|
# 为动作添加噪声
|
||||||
|
noise = torch.randn(actions.shape, device=obs_cond.device)
|
||||||
|
|
||||||
|
# 为每个数据点采样一个扩散迭代
|
||||||
|
timesteps = torch.randint(
|
||||||
|
0,
|
||||||
|
self.noise_scheduler.config.num_train_timesteps,
|
||||||
|
(B,),
|
||||||
|
device=obs_cond.device,
|
||||||
|
).long()
|
||||||
|
|
||||||
|
# 根据每个扩散迭代的噪声幅度向干净动作添加噪声
|
||||||
|
noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)
|
||||||
|
|
||||||
|
# 预测噪声残差
|
||||||
|
noise_pred = nets["policy"]["noise_pred_net"](
|
||||||
|
noisy_actions, timesteps, global_cond=obs_cond
|
||||||
|
)
|
||||||
|
|
||||||
|
# L2损失
|
||||||
|
all_l2 = F.mse_loss(noise_pred, noise, reduction="none")
|
||||||
|
loss = (all_l2 * ~is_pad.unsqueeze(-1)).mean()
|
||||||
|
|
||||||
|
loss_dict = {}
|
||||||
|
loss_dict["l2_loss"] = loss
|
||||||
|
loss_dict["loss"] = loss
|
||||||
|
|
||||||
|
if self.training and self.ema is not None:
|
||||||
|
self.ema.step(nets)
|
||||||
|
return loss_dict
|
||||||
|
else: # 推理时
|
||||||
|
To = self.observation_horizon
|
||||||
|
Ta = self.action_horizon
|
||||||
|
Tp = self.prediction_horizon
|
||||||
|
action_dim = self.ac_dim
|
||||||
|
|
||||||
|
nets = self.nets
|
||||||
|
if self.ema is not None:
|
||||||
|
nets = self.ema.averaged_model
|
||||||
|
|
||||||
|
all_features = []
|
||||||
|
for cam_id in range(len(self.camera_names)):
|
||||||
|
cam_image = image[:, cam_id]
|
||||||
|
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
|
||||||
|
pool_features = nets["policy"]["pools"][cam_id](cam_features)
|
||||||
|
pool_features = torch.flatten(pool_features, start_dim=1)
|
||||||
|
out_features = nets["policy"]["linears"][cam_id](pool_features)
|
||||||
|
all_features.append(out_features)
|
||||||
|
|
||||||
|
obs_cond = torch.cat(all_features + [qpos], dim=1)
|
||||||
|
|
||||||
|
# 从高斯噪声初始化动作
|
||||||
|
noisy_action = torch.randn((B, Tp, action_dim), device=obs_cond.device)
|
||||||
|
naction = noisy_action
|
||||||
|
|
||||||
|
# 初始化调度器
|
||||||
|
self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
|
||||||
|
|
||||||
|
for k in self.noise_scheduler.timesteps:
|
||||||
|
# 预测噪声
|
||||||
|
noise_pred = nets["policy"]["noise_pred_net"](
|
||||||
|
sample=naction, timestep=k, global_cond=obs_cond
|
||||||
|
)
|
||||||
|
|
||||||
|
# 逆扩散步骤(去除噪声)
|
||||||
|
naction = self.noise_scheduler.step(
|
||||||
|
model_output=noise_pred, timestep=k, sample=naction
|
||||||
|
).prev_sample
|
||||||
|
|
||||||
|
return naction
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
"""
|
||||||
|
序列化模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 模型状态字典
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"nets": self.nets.state_dict(),
|
||||||
|
"ema": (
|
||||||
|
self.ema.averaged_model.state_dict() if self.ema is not None else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
def deserialize(self, model_dict):
|
||||||
|
"""
|
||||||
|
反序列化模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_dict (dict): 模型状态字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
status: 加载状态
|
||||||
|
"""
|
||||||
|
status = self.nets.load_state_dict(model_dict["nets"])
|
||||||
|
logger.info("Loaded model")
|
||||||
|
if model_dict.get("ema", None) is not None:
|
||||||
|
logger.info("Loaded EMA")
|
||||||
|
status_ema = self.ema.averaged_model.load_state_dict(model_dict["ema"])
|
||||||
|
status = [status, status_ema]
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
class ACTPolicy(nn.Module):
|
||||||
|
def __init__(self, act_config):
|
||||||
|
"""
|
||||||
|
初始化ACTPolicy类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args_override (dict): 参数覆盖字典
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
lr_backbone = act_config["lr_backbone"]
|
||||||
|
vq = act_config["vq"]
|
||||||
|
lr = act_config["lr"]
|
||||||
|
weight_decay = act_config["weight_decay"]
|
||||||
|
|
||||||
|
model = build_vae(
|
||||||
|
act_config["hidden_dim"],
|
||||||
|
act_config["state_dim"],
|
||||||
|
act_config["position_embedding"],
|
||||||
|
lr_backbone,
|
||||||
|
act_config["masks"],
|
||||||
|
act_config["backbone"],
|
||||||
|
act_config["dilation"],
|
||||||
|
act_config["dropout"],
|
||||||
|
act_config["nheads"],
|
||||||
|
act_config["dim_feedforward"],
|
||||||
|
act_config["enc_layers"],
|
||||||
|
act_config["dec_layers"],
|
||||||
|
act_config["pre_norm"],
|
||||||
|
act_config["num_queries"],
|
||||||
|
act_config["camera_names"],
|
||||||
|
vq,
|
||||||
|
act_config["vq_class"],
|
||||||
|
act_config["vq_dim"],
|
||||||
|
act_config["action_dim"],
|
||||||
|
act_config["no_encoder"],
|
||||||
|
)
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
|
param_dicts = [
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in model.named_parameters()
|
||||||
|
if "backbone" not in n and p.requires_grad
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in model.named_parameters()
|
||||||
|
if "backbone" in n and p.requires_grad
|
||||||
|
],
|
||||||
|
"lr": lr_backbone,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
self.optimizer = torch.optim.AdamW(
|
||||||
|
param_dicts, lr=lr, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
self.model = model # CVAE解码器
|
||||||
|
self.kl_weight = act_config["kl_weight"]
|
||||||
|
self.vq = vq
|
||||||
|
logger.info(f"KL Weight {self.kl_weight}")
|
||||||
|
|
||||||
|
def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None):
|
||||||
|
"""
|
||||||
|
前向传播函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
qpos (torch.Tensor): 角度张量
|
||||||
|
image (torch.Tensor): 图像张量
|
||||||
|
actions (torch.Tensor, optional): 动作张量
|
||||||
|
is_pad (torch.Tensor, optional): 填充张量
|
||||||
|
vq_sample (torch.Tensor, optional): VQ样本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 损失字典(训练时)
|
||||||
|
torch.Tensor: 动作张量(推理时)
|
||||||
|
"""
|
||||||
|
env_state = None
|
||||||
|
normalize = transforms.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
)
|
||||||
|
image = normalize(image)
|
||||||
|
if actions is not None: # 训练时
|
||||||
|
actions = actions[:, : self.model.num_queries]
|
||||||
|
is_pad = is_pad[:, : self.model.num_queries]
|
||||||
|
|
||||||
|
loss_dict = dict()
|
||||||
|
a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model(
|
||||||
|
qpos, image, env_state, actions, is_pad
|
||||||
|
)
|
||||||
|
if self.vq or self.model.encoder is None:
|
||||||
|
total_kld = [torch.tensor(0.0)]
|
||||||
|
else:
|
||||||
|
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
|
||||||
|
if self.vq:
|
||||||
|
loss_dict["vq_discrepancy"] = F.l1_loss(
|
||||||
|
probs, binaries, reduction="mean"
|
||||||
|
)
|
||||||
|
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
|
||||||
|
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
|
||||||
|
loss_dict["l1"] = l1
|
||||||
|
loss_dict["kl"] = total_kld[0]
|
||||||
|
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
|
||||||
|
return loss_dict
|
||||||
|
else: # 推理时
|
||||||
|
a_hat, _, (_, _), _, _ = self.model(
|
||||||
|
qpos, image, env_state, vq_sample=vq_sample
|
||||||
|
) # no action, sample from prior
|
||||||
|
return a_hat
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
"""
|
||||||
|
配置优化器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
optimizer: 配置的优化器
|
||||||
|
"""
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def vq_encode(self, qpos, actions, is_pad):
|
||||||
|
"""
|
||||||
|
VQ编码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
qpos (torch.Tensor): 位置张量
|
||||||
|
actions (torch.Tensor): 动作张量
|
||||||
|
is_pad (torch.Tensor): 填充张量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: 二进制编码
|
||||||
|
"""
|
||||||
|
actions = actions[:, : self.model.num_queries]
|
||||||
|
is_pad = is_pad[:, : self.model.num_queries]
|
||||||
|
|
||||||
|
_, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad)
|
||||||
|
|
||||||
|
return binaries
|
||||||
|
|
||||||
|
def serialize(self):
|
||||||
|
"""
|
||||||
|
序列化模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 模型状态字典
|
||||||
|
"""
|
||||||
|
return self.state_dict()
|
||||||
|
|
||||||
|
def deserialize(self, model_dict):
|
||||||
|
"""
|
||||||
|
反序列化模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_dict (dict): 模型状态字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
status: 加载状态
|
||||||
|
"""
|
||||||
|
return self.load_state_dict(model_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class CNNMLPPolicy(nn.Module):
|
||||||
|
def __init__(self, args_override):
|
||||||
|
"""
|
||||||
|
初始化CNNMLPPolicy类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
args_override (dict): 参数覆盖字典
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
# parser = argparse.ArgumentParser(
|
||||||
|
# "DETR training and evaluation script", parents=[get_args_parser()]
|
||||||
|
# )
|
||||||
|
# args = parser.parse_args()
|
||||||
|
|
||||||
|
# for k, v in args_override.items():
|
||||||
|
# setattr(args, k, v)
|
||||||
|
|
||||||
|
model = build_cnnmlp(args_override)
|
||||||
|
model.cuda()
|
||||||
|
|
||||||
|
param_dicts = [
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in model.named_parameters()
|
||||||
|
if "backbone" not in n and p.requires_grad
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"params": [
|
||||||
|
p
|
||||||
|
for n, p in model.named_parameters()
|
||||||
|
if "backbone" in n and p.requires_grad
|
||||||
|
],
|
||||||
|
"lr": args_override.lr_backbone,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
self.model = model # 解码器
|
||||||
|
self.optimizer = torch.optim.AdamW(
|
||||||
|
param_dicts, lr=args_override.lr, weight_decay=args_override.weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, qpos, image, actions=None, is_pad=None):
|
||||||
|
"""
|
||||||
|
前向传播函数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
qpos (torch.Tensor): 位置张量
|
||||||
|
image (torch.Tensor): 图像张量
|
||||||
|
actions (torch.Tensor, optional): 动作张量
|
||||||
|
is_pad (torch.Tensor, optional): 填充张量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 损失字典(训练时)
|
||||||
|
torch.Tensor: 动作张量(推理时)
|
||||||
|
"""
|
||||||
|
env_state = None
|
||||||
|
normalize = transforms.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
||||||
|
)
|
||||||
|
image = normalize(image)
|
||||||
|
if actions is not None: # 训练时
|
||||||
|
actions = actions[:, 0]
|
||||||
|
a_hat = self.model(qpos, image, env_state, actions)
|
||||||
|
mse = F.mse_loss(actions, a_hat)
|
||||||
|
loss_dict = dict()
|
||||||
|
loss_dict["mse"] = mse
|
||||||
|
loss_dict["loss"] = loss_dict["mse"]
|
||||||
|
return loss_dict
|
||||||
|
else: # 推理时
|
||||||
|
a_hat = self.model(qpos, image, env_state) # 无动作,从先验中采样
|
||||||
|
return a_hat
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
"""
|
||||||
|
配置优化器
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
optimizer: 配置的优化器
|
||||||
|
"""
|
||||||
|
return self.optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def kl_divergence(mu, logvar):
|
||||||
|
"""
|
||||||
|
计算KL散度
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mu (torch.Tensor): 均值张量
|
||||||
|
logvar (torch.Tensor): 对数方差张量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: 总KL散度,维度KL散度,均值KL散度
|
||||||
|
"""
|
||||||
|
batch_size = mu.size(0)
|
||||||
|
assert batch_size != 0
|
||||||
|
if mu.data.ndimension() == 4:
|
||||||
|
mu = mu.view(mu.size(0), mu.size(1))
|
||||||
|
if logvar.data.ndimension() == 4:
|
||||||
|
logvar = logvar.view(logvar.size(0), logvar.size(1))
|
||||||
|
|
||||||
|
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
|
||||||
|
total_kld = klds.sum(1).mean(0, True)
|
||||||
|
dimension_wise_kld = klds.mean(0)
|
||||||
|
mean_kld = klds.mean(1).mean(0, True)
|
||||||
|
|
||||||
|
return total_kld, dimension_wise_kld, mean_kld
|
||||||
@@ -0,0 +1,245 @@
|
|||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
import pickle
|
||||||
|
import torch
|
||||||
|
# import wandb
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
from copy import deepcopy
|
||||||
|
from itertools import repeat
|
||||||
|
from shadow_act.utils.utils import (
|
||||||
|
set_seed,
|
||||||
|
load_data,
|
||||||
|
compute_dict_mean,
|
||||||
|
)
|
||||||
|
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||||
|
from shadow_act.eval.rm_act_eval import RmActEvaluator
|
||||||
|
|
||||||
|
# 配置日志记录
|
||||||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||||
|
|
||||||
|
class RmActTrainer:
|
||||||
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
初始化训练器,设置随机种子,加载数据,保存数据统计信息。
|
||||||
|
"""
|
||||||
|
self._config = config
|
||||||
|
self._num_steps = config["num_steps"]
|
||||||
|
self._ckpt_dir = config["checkpoint_dir"]
|
||||||
|
self._state_dim = config["state_dim"]
|
||||||
|
self._real_robot = config["real_robot"]
|
||||||
|
self._policy_class = config["policy_class"]
|
||||||
|
self._onscreen_render = config["onscreen_render"]
|
||||||
|
self._policy_config = config["policy_config"]
|
||||||
|
self._camera_names = config["camera_names"]
|
||||||
|
self._max_timesteps = config["episode_len"]
|
||||||
|
self._task_name = config["task_name"]
|
||||||
|
self._temporal_agg = config["temporal_agg"]
|
||||||
|
self._onscreen_cam = "angle"
|
||||||
|
self._vq = config["policy_config"]["vq"]
|
||||||
|
self._batch_size = config["batch_size"]
|
||||||
|
|
||||||
|
self._seed = config["seed"]
|
||||||
|
self._eval_every = config["eval_every"]
|
||||||
|
self._validate_every = config["validate_every"]
|
||||||
|
self._save_every = config["save_every"]
|
||||||
|
self._load_pretrain = config["load_pretrain"]
|
||||||
|
self._resume_ckpt_path = config["resume_ckpt_path"]
|
||||||
|
|
||||||
|
if config["name_filter"] is None:
|
||||||
|
name_filter = lambda n : True
|
||||||
|
else:
|
||||||
|
name_filter = config["name_filter"]
|
||||||
|
|
||||||
|
self._eval = RmActEvaluator(config, True, 50)
|
||||||
|
# 加载数据
|
||||||
|
self._train_dataloader, self._val_dataloader, self._stats, _ = load_data(
|
||||||
|
config["dataset_dir"],
|
||||||
|
name_filter,
|
||||||
|
self._camera_names,
|
||||||
|
self._batch_size,
|
||||||
|
self._batch_size,
|
||||||
|
config["chunk_size"],
|
||||||
|
config["skip_mirrored_data"],
|
||||||
|
self._load_pretrain,
|
||||||
|
self._policy_class,
|
||||||
|
config["stats_dir"],
|
||||||
|
config["sample_weights"],
|
||||||
|
config["train_ratio"],
|
||||||
|
)
|
||||||
|
# 保存数据统计信息
|
||||||
|
stats_path = os.path.join(self._ckpt_dir, "dataset_stats.pkl")
|
||||||
|
with open(stats_path, "wb") as f:
|
||||||
|
pickle.dump(self._stats, f)
|
||||||
|
expr_name = self._ckpt_dir.split("/")[-1]
|
||||||
|
|
||||||
|
# wandb.init(
|
||||||
|
# project="train_rm_aloha", reinit=True, entity="train_rm_aloha", name=expr_name
|
||||||
|
# )
|
||||||
|
|
||||||
|
|
||||||
|
def _make_policy(self):
|
||||||
|
"""
|
||||||
|
根据策略类和配置创建策略对象
|
||||||
|
"""
|
||||||
|
if self._policy_class == "ACT":
|
||||||
|
return ACTPolicy(self._policy_config)
|
||||||
|
elif self._policy_class == "CNNMLP":
|
||||||
|
return CNNMLPPolicy(self._policy_config)
|
||||||
|
elif self._policy_class == "Diffusion":
|
||||||
|
return DiffusionPolicy(self._policy_config)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Policy class {self._policy_class} is not implemented")
|
||||||
|
|
||||||
|
def _make_optimizer(self):
|
||||||
|
"""
|
||||||
|
根据策略类创建优化器
|
||||||
|
"""
|
||||||
|
if self._policy_class in ["ACT", "CNNMLP", "Diffusion"]:
|
||||||
|
return self._policy.configure_optimizers()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _forward_pass(self, data):
|
||||||
|
"""
|
||||||
|
前向传播,计算损失
|
||||||
|
"""
|
||||||
|
image_data, qpos_data, action_data, is_pad = data
|
||||||
|
try:
|
||||||
|
image_data, qpos_data, action_data, is_pad = (
|
||||||
|
image_data.cuda(),
|
||||||
|
qpos_data.cuda(),
|
||||||
|
action_data.cuda(),
|
||||||
|
is_pad.cuda(),
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logging.error(f"CUDA error: {e}")
|
||||||
|
raise
|
||||||
|
return self._policy(qpos_data, image_data, action_data, is_pad)
|
||||||
|
|
||||||
|
def _repeater(self):
|
||||||
|
"""
|
||||||
|
数据加载器的重复器,生成数据
|
||||||
|
"""
|
||||||
|
epoch = 0
|
||||||
|
for loader in repeat(self._train_dataloader):
|
||||||
|
for data in loader:
|
||||||
|
yield data
|
||||||
|
logging.info(f"Epoch {epoch} done")
|
||||||
|
epoch += 1
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
"""
|
||||||
|
训练模型,保存最佳模型
|
||||||
|
"""
|
||||||
|
set_seed(self._seed)
|
||||||
|
self._policy = self._make_policy()
|
||||||
|
min_val_loss = np.inf
|
||||||
|
best_ckpt_info = None
|
||||||
|
|
||||||
|
# 加载预训练模型
|
||||||
|
if self._load_pretrain:
|
||||||
|
try:
|
||||||
|
loading_status = self._policy.deserialize(
|
||||||
|
torch.load(
|
||||||
|
os.path.join(
|
||||||
|
"/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all",
|
||||||
|
"policy_step_50000_seed_0.ckpt",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logging.info(f"loaded! {loading_status}")
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logging.error(f"Pretrain model not found: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error loading pretrain model: {e}")
|
||||||
|
|
||||||
|
# 恢复检查点
|
||||||
|
if self._resume_ckpt_path is not None:
|
||||||
|
try:
|
||||||
|
loading_status = self._policy.deserialize(torch.load(self._resume_ckpt_path))
|
||||||
|
logging.info(f"Resume policy from: {self._resume_ckpt_path}, Status: {loading_status}")
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logging.error(f"Checkpoint not found: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error loading checkpoint: {e}")
|
||||||
|
|
||||||
|
self._policy.cuda()
|
||||||
|
|
||||||
|
self._optimizer = self._make_optimizer()
|
||||||
|
train_dataloader = self._repeater() # 重复器
|
||||||
|
|
||||||
|
for step in tqdm(range(self._num_steps + 1)):
|
||||||
|
# 验证模型
|
||||||
|
if step % self._validate_every != 0:
|
||||||
|
continue
|
||||||
|
logging.info("validating")
|
||||||
|
with torch.inference_mode():
|
||||||
|
self._policy.eval()
|
||||||
|
validation_dicts = []
|
||||||
|
for batch_idx, data in enumerate(self._val_dataloader):
|
||||||
|
forward_dict = self._forward_pass(data) # forward_dict = {"loss": loss, "kl": kl, "mse": mse}
|
||||||
|
validation_dicts.append(forward_dict)
|
||||||
|
if batch_idx > 50: # 限制验证批次数 TODO 确定批次关联
|
||||||
|
break
|
||||||
|
|
||||||
|
validation_summary = compute_dict_mean(validation_dicts)
|
||||||
|
epoch_val_loss = validation_summary["loss"]
|
||||||
|
if epoch_val_loss < min_val_loss:
|
||||||
|
min_val_loss = epoch_val_loss
|
||||||
|
best_ckpt_info = (
|
||||||
|
step,
|
||||||
|
min_val_loss,
|
||||||
|
deepcopy(self._policy.serialize()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# wandb记录验证结果
|
||||||
|
# for k in list(validation_summary.keys()):
|
||||||
|
# validation_summary[f"val_{k}"] = validation_summary.pop(k)
|
||||||
|
|
||||||
|
# wandb.log(validation_summary, step=step)
|
||||||
|
logging.info(f"Val loss: {epoch_val_loss:.5f}")
|
||||||
|
summary_string = " ".join(f"{k}: {v.item():.3f}" for k, v in validation_summary.items())
|
||||||
|
logging.info(summary_string)
|
||||||
|
|
||||||
|
# 评估模型
|
||||||
|
# if (step > 0) and (step % self._eval_every == 0):
|
||||||
|
# ckpt_name = f"policy_step_{step}_seed_{self._seed}.ckpt"
|
||||||
|
# ckpt_path = os.path.join(self._ckpt_dir, ckpt_name)
|
||||||
|
# torch.save(self._policy.serialize(), ckpt_path)
|
||||||
|
# success, _ = self._eval.evaluate(ckpt_name)
|
||||||
|
# wandb.log({"success": success}, step=step)
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
self._policy.train()
|
||||||
|
self._optimizer.zero_grad()
|
||||||
|
data = next(train_dataloader)
|
||||||
|
forward_dict = self._forward_pass(data)
|
||||||
|
loss = forward_dict["loss"]
|
||||||
|
loss.backward()
|
||||||
|
self._optimizer.step()
|
||||||
|
# wandb.log(forward_dict, step=step)
|
||||||
|
|
||||||
|
# 保存检查点
|
||||||
|
if step % self._save_every == 0:
|
||||||
|
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{step}_seed_{self._seed}.ckpt")
|
||||||
|
torch.save(self._policy.serialize(), ckpt_path)
|
||||||
|
|
||||||
|
# 保存最后的模型
|
||||||
|
ckpt_path = os.path.join(self._ckpt_dir, "policy_last.ckpt")
|
||||||
|
torch.save(self._policy.serialize(), ckpt_path)
|
||||||
|
|
||||||
|
best_step, min_val_loss, best_state_dict = best_ckpt_info
|
||||||
|
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{best_step}_seed_{self._seed}.ckpt")
|
||||||
|
torch.save(best_state_dict, ckpt_path)
|
||||||
|
logging.info(f"Training finished:\nSeed {self._seed}, val loss {min_val_loss:.6f} at step {best_step}")
|
||||||
|
|
||||||
|
return best_ckpt_info
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
with open("/home/rm/aloha/shadow_rm_act/config/config.yaml") as f:
|
||||||
|
config = yaml.load(f, Loader=yaml.FullLoader)
|
||||||
|
trainer = RmActTrainer(config)
|
||||||
|
trainer.train()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
@@ -0,0 +1,88 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Utilities for bounding box manipulation and GIoU.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from torchvision.ops.boxes import box_area
|
||||||
|
|
||||||
|
|
||||||
|
def box_cxcywh_to_xyxy(x):
|
||||||
|
x_c, y_c, w, h = x.unbind(-1)
|
||||||
|
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
|
||||||
|
(x_c + 0.5 * w), (y_c + 0.5 * h)]
|
||||||
|
return torch.stack(b, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def box_xyxy_to_cxcywh(x):
|
||||||
|
x0, y0, x1, y1 = x.unbind(-1)
|
||||||
|
b = [(x0 + x1) / 2, (y0 + y1) / 2,
|
||||||
|
(x1 - x0), (y1 - y0)]
|
||||||
|
return torch.stack(b, dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# modified from torchvision to also return the union
|
||||||
|
def box_iou(boxes1, boxes2):
|
||||||
|
area1 = box_area(boxes1)
|
||||||
|
area2 = box_area(boxes2)
|
||||||
|
|
||||||
|
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
|
||||||
|
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
|
||||||
|
|
||||||
|
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||||
|
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
|
||||||
|
|
||||||
|
union = area1[:, None] + area2 - inter
|
||||||
|
|
||||||
|
iou = inter / union
|
||||||
|
return iou, union
|
||||||
|
|
||||||
|
|
||||||
|
def generalized_box_iou(boxes1, boxes2):
|
||||||
|
"""
|
||||||
|
Generalized IoU from https://giou.stanford.edu/
|
||||||
|
|
||||||
|
The boxes should be in [x0, y0, x1, y1] format
|
||||||
|
|
||||||
|
Returns a [N, M] pairwise matrix, where N = len(boxes1)
|
||||||
|
and M = len(boxes2)
|
||||||
|
"""
|
||||||
|
# degenerate boxes gives inf / nan results
|
||||||
|
# so do an early check
|
||||||
|
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
|
||||||
|
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
|
||||||
|
iou, union = box_iou(boxes1, boxes2)
|
||||||
|
|
||||||
|
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
|
||||||
|
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
|
||||||
|
|
||||||
|
wh = (rb - lt).clamp(min=0) # [N,M,2]
|
||||||
|
area = wh[:, :, 0] * wh[:, :, 1]
|
||||||
|
|
||||||
|
return iou - (area - union) / area
|
||||||
|
|
||||||
|
|
||||||
|
def masks_to_boxes(masks):
|
||||||
|
"""Compute the bounding boxes around the provided masks
|
||||||
|
|
||||||
|
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
|
||||||
|
|
||||||
|
Returns a [N, 4] tensors, with the boxes in xyxy format
|
||||||
|
"""
|
||||||
|
if masks.numel() == 0:
|
||||||
|
return torch.zeros((0, 4), device=masks.device)
|
||||||
|
|
||||||
|
h, w = masks.shape[-2:]
|
||||||
|
|
||||||
|
y = torch.arange(0, h, dtype=torch.float)
|
||||||
|
x = torch.arange(0, w, dtype=torch.float)
|
||||||
|
y, x = torch.meshgrid(y, x)
|
||||||
|
|
||||||
|
x_mask = (masks * x.unsqueeze(0))
|
||||||
|
x_max = x_mask.flatten(1).max(-1)[0]
|
||||||
|
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||||
|
|
||||||
|
y_mask = (masks * y.unsqueeze(0))
|
||||||
|
y_max = y_mask.flatten(1).max(-1)[0]
|
||||||
|
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
|
||||||
|
|
||||||
|
return torch.stack([x_min, y_min, x_max, y_max], 1)
|
||||||
@@ -0,0 +1,468 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||||
|
"""
|
||||||
|
Misc functions, including distributed helpers.
|
||||||
|
|
||||||
|
Mostly copy-paste from torchvision references.
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
import datetime
|
||||||
|
import pickle
|
||||||
|
from packaging import version
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
||||||
|
import torchvision
|
||||||
|
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||||
|
from torchvision.ops import _new_empty_tensor
|
||||||
|
from torchvision.ops.misc import _output_size
|
||||||
|
|
||||||
|
|
||||||
|
class SmoothedValue(object):
|
||||||
|
"""Track a series of values and provide access to smoothed values over a
|
||||||
|
window or the global series average.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size=20, fmt=None):
|
||||||
|
if fmt is None:
|
||||||
|
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||||
|
self.deque = deque(maxlen=window_size)
|
||||||
|
self.total = 0.0
|
||||||
|
self.count = 0
|
||||||
|
self.fmt = fmt
|
||||||
|
|
||||||
|
def update(self, value, n=1):
|
||||||
|
self.deque.append(value)
|
||||||
|
self.count += n
|
||||||
|
self.total += value * n
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
"""
|
||||||
|
Warning: does not synchronize the deque!
|
||||||
|
"""
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return
|
||||||
|
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
||||||
|
dist.barrier()
|
||||||
|
dist.all_reduce(t)
|
||||||
|
t = t.tolist()
|
||||||
|
self.count = int(t[0])
|
||||||
|
self.total = t[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median(self):
|
||||||
|
d = torch.tensor(list(self.deque))
|
||||||
|
return d.median().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self):
|
||||||
|
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||||
|
return d.mean().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_avg(self):
|
||||||
|
return self.total / self.count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self):
|
||||||
|
return max(self.deque)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
return self.deque[-1]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.format(
|
||||||
|
median=self.median,
|
||||||
|
avg=self.avg,
|
||||||
|
global_avg=self.global_avg,
|
||||||
|
max=self.max,
|
||||||
|
value=self.value)
|
||||||
|
|
||||||
|
|
||||||
|
def all_gather(data):
|
||||||
|
"""
|
||||||
|
Run all_gather on arbitrary picklable data (not necessarily tensors)
|
||||||
|
Args:
|
||||||
|
data: any picklable object
|
||||||
|
Returns:
|
||||||
|
list[data]: list of data gathered from each rank
|
||||||
|
"""
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size == 1:
|
||||||
|
return [data]
|
||||||
|
|
||||||
|
# serialized to a Tensor
|
||||||
|
buffer = pickle.dumps(data)
|
||||||
|
storage = torch.ByteStorage.from_buffer(buffer)
|
||||||
|
tensor = torch.ByteTensor(storage).to("cuda")
|
||||||
|
|
||||||
|
# obtain Tensor size of each rank
|
||||||
|
local_size = torch.tensor([tensor.numel()], device="cuda")
|
||||||
|
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
|
||||||
|
dist.all_gather(size_list, local_size)
|
||||||
|
size_list = [int(size.item()) for size in size_list]
|
||||||
|
max_size = max(size_list)
|
||||||
|
|
||||||
|
# receiving Tensor from all ranks
|
||||||
|
# we pad the tensor because torch all_gather does not support
|
||||||
|
# gathering tensors of different shapes
|
||||||
|
tensor_list = []
|
||||||
|
for _ in size_list:
|
||||||
|
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
|
||||||
|
if local_size != max_size:
|
||||||
|
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
|
||||||
|
tensor = torch.cat((tensor, padding), dim=0)
|
||||||
|
dist.all_gather(tensor_list, tensor)
|
||||||
|
|
||||||
|
data_list = []
|
||||||
|
for size, tensor in zip(size_list, tensor_list):
|
||||||
|
buffer = tensor.cpu().numpy().tobytes()[:size]
|
||||||
|
data_list.append(pickle.loads(buffer))
|
||||||
|
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_dict(input_dict, average=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
input_dict (dict): all the values will be reduced
|
||||||
|
average (bool): whether to do average or sum
|
||||||
|
Reduce the values in the dictionary from all processes so that all processes
|
||||||
|
have the averaged results. Returns a dict with the same fields as
|
||||||
|
input_dict, after reduction.
|
||||||
|
"""
|
||||||
|
world_size = get_world_size()
|
||||||
|
if world_size < 2:
|
||||||
|
return input_dict
|
||||||
|
with torch.no_grad():
|
||||||
|
names = []
|
||||||
|
values = []
|
||||||
|
# sort the keys so that they are consistent across processes
|
||||||
|
for k in sorted(input_dict.keys()):
|
||||||
|
names.append(k)
|
||||||
|
values.append(input_dict[k])
|
||||||
|
values = torch.stack(values, dim=0)
|
||||||
|
dist.all_reduce(values)
|
||||||
|
if average:
|
||||||
|
values /= world_size
|
||||||
|
reduced_dict = {k: v for k, v in zip(names, values)}
|
||||||
|
return reduced_dict
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger(object):
|
||||||
|
def __init__(self, delimiter="\t"):
|
||||||
|
self.meters = defaultdict(SmoothedValue)
|
||||||
|
self.delimiter = delimiter
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.item()
|
||||||
|
assert isinstance(v, (float, int))
|
||||||
|
self.meters[k].update(v)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in self.meters:
|
||||||
|
return self.meters[attr]
|
||||||
|
if attr in self.__dict__:
|
||||||
|
return self.__dict__[attr]
|
||||||
|
raise AttributeError("'{}' object has no attribute '{}'".format(
|
||||||
|
type(self).__name__, attr))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
loss_str = []
|
||||||
|
for name, meter in self.meters.items():
|
||||||
|
loss_str.append(
|
||||||
|
"{}: {}".format(name, str(meter))
|
||||||
|
)
|
||||||
|
return self.delimiter.join(loss_str)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for meter in self.meters.values():
|
||||||
|
meter.synchronize_between_processes()
|
||||||
|
|
||||||
|
def add_meter(self, name, meter):
|
||||||
|
self.meters[name] = meter
|
||||||
|
|
||||||
|
def log_every(self, iterable, print_freq, header=None):
|
||||||
|
i = 0
|
||||||
|
if not header:
|
||||||
|
header = ''
|
||||||
|
start_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
iter_time = SmoothedValue(fmt='{avg:.4f}')
|
||||||
|
data_time = SmoothedValue(fmt='{avg:.4f}')
|
||||||
|
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_msg = self.delimiter.join([
|
||||||
|
header,
|
||||||
|
'[{0' + space_fmt + '}/{1}]',
|
||||||
|
'eta: {eta}',
|
||||||
|
'{meters}',
|
||||||
|
'time: {time}',
|
||||||
|
'data: {data}',
|
||||||
|
'max mem: {memory:.0f}'
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
log_msg = self.delimiter.join([
|
||||||
|
header,
|
||||||
|
'[{0' + space_fmt + '}/{1}]',
|
||||||
|
'eta: {eta}',
|
||||||
|
'{meters}',
|
||||||
|
'time: {time}',
|
||||||
|
'data: {data}'
|
||||||
|
])
|
||||||
|
MB = 1024.0 * 1024.0
|
||||||
|
for obj in iterable:
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
yield obj
|
||||||
|
iter_time.update(time.time() - end)
|
||||||
|
if i % print_freq == 0 or i == len(iterable) - 1:
|
||||||
|
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||||
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(log_msg.format(
|
||||||
|
i, len(iterable), eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time), data=str(data_time),
|
||||||
|
memory=torch.cuda.max_memory_allocated() / MB))
|
||||||
|
else:
|
||||||
|
print(log_msg.format(
|
||||||
|
i, len(iterable), eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time), data=str(data_time)))
|
||||||
|
i += 1
|
||||||
|
end = time.time()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print('{} Total time: {} ({:.4f} s / it)'.format(
|
||||||
|
header, total_time_str, total_time / len(iterable)))
|
||||||
|
|
||||||
|
|
||||||
|
def get_sha():
|
||||||
|
cwd = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
def _run(command):
|
||||||
|
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
||||||
|
sha = 'N/A'
|
||||||
|
diff = "clean"
|
||||||
|
branch = 'N/A'
|
||||||
|
try:
|
||||||
|
sha = _run(['git', 'rev-parse', 'HEAD'])
|
||||||
|
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
||||||
|
diff = _run(['git', 'diff-index', 'HEAD'])
|
||||||
|
diff = "has uncommited changes" if diff else "clean"
|
||||||
|
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
batch = list(zip(*batch))
|
||||||
|
batch[0] = nested_tensor_from_tensor_list(batch[0])
|
||||||
|
return tuple(batch)
|
||||||
|
|
||||||
|
|
||||||
|
def _max_by_axis(the_list):
|
||||||
|
# type: (List[List[int]]) -> List[int]
|
||||||
|
maxes = the_list[0]
|
||||||
|
for sublist in the_list[1:]:
|
||||||
|
for index, item in enumerate(sublist):
|
||||||
|
maxes[index] = max(maxes[index], item)
|
||||||
|
return maxes
|
||||||
|
|
||||||
|
|
||||||
|
class NestedTensor(object):
|
||||||
|
def __init__(self, tensors, mask: Optional[Tensor]):
|
||||||
|
self.tensors = tensors
|
||||||
|
self.mask = mask
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
# type: (Device) -> NestedTensor # noqa
|
||||||
|
cast_tensor = self.tensors.to(device)
|
||||||
|
mask = self.mask
|
||||||
|
if mask is not None:
|
||||||
|
assert mask is not None
|
||||||
|
cast_mask = mask.to(device)
|
||||||
|
else:
|
||||||
|
cast_mask = None
|
||||||
|
return NestedTensor(cast_tensor, cast_mask)
|
||||||
|
|
||||||
|
def decompose(self):
|
||||||
|
return self.tensors, self.mask
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return str(self.tensors)
|
||||||
|
|
||||||
|
|
||||||
|
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
||||||
|
# TODO make this more general
|
||||||
|
if tensor_list[0].ndim == 3:
|
||||||
|
if torchvision._is_tracing():
|
||||||
|
# nested_tensor_from_tensor_list() does not export well to ONNX
|
||||||
|
# call _onnx_nested_tensor_from_tensor_list() instead
|
||||||
|
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
||||||
|
|
||||||
|
# TODO make it support different-sized images
|
||||||
|
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
||||||
|
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
||||||
|
batch_shape = [len(tensor_list)] + max_size
|
||||||
|
b, c, h, w = batch_shape
|
||||||
|
dtype = tensor_list[0].dtype
|
||||||
|
device = tensor_list[0].device
|
||||||
|
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
||||||
|
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
||||||
|
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
||||||
|
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||||
|
m[: img.shape[1], :img.shape[2]] = False
|
||||||
|
else:
|
||||||
|
raise ValueError('not supported')
|
||||||
|
return NestedTensor(tensor, mask)
|
||||||
|
|
||||||
|
|
||||||
|
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
||||||
|
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
||||||
|
@torch.jit.unused
|
||||||
|
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
||||||
|
max_size = []
|
||||||
|
for i in range(tensor_list[0].dim()):
|
||||||
|
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
||||||
|
max_size.append(max_size_i)
|
||||||
|
max_size = tuple(max_size)
|
||||||
|
|
||||||
|
# work around for
|
||||||
|
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
||||||
|
# m[: img.shape[1], :img.shape[2]] = False
|
||||||
|
# which is not yet supported in onnx
|
||||||
|
padded_imgs = []
|
||||||
|
padded_masks = []
|
||||||
|
for img in tensor_list:
|
||||||
|
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
||||||
|
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
||||||
|
padded_imgs.append(padded_img)
|
||||||
|
|
||||||
|
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
||||||
|
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
||||||
|
padded_masks.append(padded_mask.to(torch.bool))
|
||||||
|
|
||||||
|
tensor = torch.stack(padded_imgs)
|
||||||
|
mask = torch.stack(padded_masks)
|
||||||
|
|
||||||
|
return NestedTensor(tensor, mask=mask)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop('force', False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_on_master(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
torch.save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ['WORLD_SIZE'])
|
||||||
|
args.gpu = int(os.environ['LOCAL_RANK'])
|
||||||
|
elif 'SLURM_PROCID' in os.environ:
|
||||||
|
args.rank = int(os.environ['SLURM_PROCID'])
|
||||||
|
args.gpu = args.rank % torch.cuda.device_count()
|
||||||
|
else:
|
||||||
|
print('Not using distributed mode')
|
||||||
|
args.distributed = False
|
||||||
|
return
|
||||||
|
|
||||||
|
args.distributed = True
|
||||||
|
|
||||||
|
torch.cuda.set_device(args.gpu)
|
||||||
|
args.dist_backend = 'nccl'
|
||||||
|
print('| distributed init (rank {}): {}'.format(
|
||||||
|
args.rank, args.dist_url), flush=True)
|
||||||
|
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
|
||||||
|
world_size=args.world_size, rank=args.rank)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
setup_for_distributed(args.rank == 0)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def accuracy(output, target, topk=(1,)):
|
||||||
|
"""Computes the precision@k for the specified values of k"""
|
||||||
|
if target.numel() == 0:
|
||||||
|
return [torch.zeros([], device=output.device)]
|
||||||
|
maxk = max(topk)
|
||||||
|
batch_size = target.size(0)
|
||||||
|
|
||||||
|
_, pred = output.topk(maxk, 1, True, True)
|
||||||
|
pred = pred.t()
|
||||||
|
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
||||||
|
|
||||||
|
res = []
|
||||||
|
for k in topk:
|
||||||
|
correct_k = correct[:k].view(-1).float().sum(0)
|
||||||
|
res.append(correct_k.mul_(100.0 / batch_size))
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
|
||||||
|
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
|
||||||
|
"""
|
||||||
|
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
|
||||||
|
This will eventually be supported natively by PyTorch, and this
|
||||||
|
class can go away.
|
||||||
|
"""
|
||||||
|
if version.parse(torchvision.__version__) < version.parse('0.7'):
|
||||||
|
if input.numel() > 0:
|
||||||
|
return torch.nn.functional.interpolate(
|
||||||
|
input, size, scale_factor, mode, align_corners
|
||||||
|
)
|
||||||
|
|
||||||
|
output_shape = _output_size(2, input, size, scale_factor)
|
||||||
|
output_shape = list(input.shape[:-2]) + list(output_shape)
|
||||||
|
return _new_empty_tensor(input, output_shape)
|
||||||
|
else:
|
||||||
|
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""
|
||||||
|
Plotting utilities to visualize training logs.
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import seaborn as sns
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
from pathlib import Path, PurePath
|
||||||
|
|
||||||
|
|
||||||
|
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
|
||||||
|
'''
|
||||||
|
Function to plot specific fields from training log(s). Plots both training and test results.
|
||||||
|
|
||||||
|
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
|
||||||
|
- fields = which results to plot from each log file - plots both training and test for each field.
|
||||||
|
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
|
||||||
|
- log_name = optional, name of log file if different than default 'log.txt'.
|
||||||
|
|
||||||
|
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
|
||||||
|
- solid lines are training results, dashed lines are test results.
|
||||||
|
|
||||||
|
'''
|
||||||
|
func_name = "plot_utils.py::plot_logs"
|
||||||
|
|
||||||
|
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
|
||||||
|
# convert single Path to list to avoid 'not iterable' error
|
||||||
|
|
||||||
|
if not isinstance(logs, list):
|
||||||
|
if isinstance(logs, PurePath):
|
||||||
|
logs = [logs]
|
||||||
|
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
|
||||||
|
Expect list[Path] or single Path obj, received {type(logs)}")
|
||||||
|
|
||||||
|
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
|
||||||
|
for i, dir in enumerate(logs):
|
||||||
|
if not isinstance(dir, PurePath):
|
||||||
|
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
|
||||||
|
if not dir.exists():
|
||||||
|
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
|
||||||
|
# verify log_name exists
|
||||||
|
fn = Path(dir / log_name)
|
||||||
|
if not fn.exists():
|
||||||
|
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
|
||||||
|
print(f"--> full path of missing log file: {fn}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# load log file(s) and plot
|
||||||
|
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
|
||||||
|
|
||||||
|
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
|
||||||
|
for j, field in enumerate(fields):
|
||||||
|
if field == 'mAP':
|
||||||
|
coco_eval = pd.DataFrame(
|
||||||
|
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
|
||||||
|
).ewm(com=ewm_col).mean()
|
||||||
|
axs[j].plot(coco_eval, c=color)
|
||||||
|
else:
|
||||||
|
df.interpolate().ewm(com=ewm_col).mean().plot(
|
||||||
|
y=[f'train_{field}', f'test_{field}'],
|
||||||
|
ax=axs[j],
|
||||||
|
color=[color] * 2,
|
||||||
|
style=['-', '--']
|
||||||
|
)
|
||||||
|
for ax, field in zip(axs, fields):
|
||||||
|
ax.legend([Path(p).name for p in logs])
|
||||||
|
ax.set_title(field)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_precision_recall(files, naming_scheme='iter'):
|
||||||
|
if naming_scheme == 'exp_id':
|
||||||
|
# name becomes exp_id
|
||||||
|
names = [f.parts[-3] for f in files]
|
||||||
|
elif naming_scheme == 'iter':
|
||||||
|
names = [f.stem for f in files]
|
||||||
|
else:
|
||||||
|
raise ValueError(f'not supported {naming_scheme}')
|
||||||
|
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
|
||||||
|
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
|
||||||
|
data = torch.load(f)
|
||||||
|
# precision is n_iou, n_points, n_cat, n_area, max_det
|
||||||
|
precision = data['precision']
|
||||||
|
recall = data['params'].recThrs
|
||||||
|
scores = data['scores']
|
||||||
|
# take precision for all classes, all areas and 100 detections
|
||||||
|
precision = precision[0, :, :, 0, -1].mean(1)
|
||||||
|
scores = scores[0, :, :, 0, -1].mean(1)
|
||||||
|
prec = precision.mean()
|
||||||
|
rec = data['recall'][0, :, 0, -1].mean()
|
||||||
|
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
|
||||||
|
f'score={scores.mean():0.3f}, ' +
|
||||||
|
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
|
||||||
|
)
|
||||||
|
axs[0].plot(recall, precision, c=color)
|
||||||
|
axs[1].plot(recall, scores, c=color)
|
||||||
|
|
||||||
|
axs[0].set_title('Precision / Recall')
|
||||||
|
axs[0].legend(names)
|
||||||
|
axs[1].set_title('Scores / Recall')
|
||||||
|
axs[1].legend(names)
|
||||||
|
return fig, axs
|
||||||
@@ -0,0 +1,499 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
import pickle
|
||||||
|
import fnmatch
|
||||||
|
import cv2
|
||||||
|
from time import time
|
||||||
|
from torch.utils.data import TensorDataset, DataLoader
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_list(l):
|
||||||
|
return [item for sublist in l for item in sublist]
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodicDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_path_list,
|
||||||
|
camera_names,
|
||||||
|
norm_stats,
|
||||||
|
episode_ids,
|
||||||
|
episode_len,
|
||||||
|
chunk_size,
|
||||||
|
policy_class,
|
||||||
|
):
|
||||||
|
super(EpisodicDataset).__init__()
|
||||||
|
self.episode_ids = episode_ids
|
||||||
|
self.dataset_path_list = dataset_path_list
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.norm_stats = norm_stats
|
||||||
|
self.episode_len = episode_len
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self.cumulative_len = np.cumsum(self.episode_len)
|
||||||
|
self.max_episode_len = max(episode_len)
|
||||||
|
self.policy_class = policy_class
|
||||||
|
if self.policy_class == "Diffusion":
|
||||||
|
self.augment_images = True
|
||||||
|
else:
|
||||||
|
self.augment_images = False
|
||||||
|
self.transformations = None
|
||||||
|
self.__getitem__(0) # initialize self.is_sim and self.transformations
|
||||||
|
self.is_sim = False
|
||||||
|
|
||||||
|
# def __len__(self):
|
||||||
|
# return sum(self.episode_len)
|
||||||
|
|
||||||
|
def _locate_transition(self, index):
|
||||||
|
assert index < self.cumulative_len[-1]
|
||||||
|
episode_index = np.argmax(
|
||||||
|
self.cumulative_len > index
|
||||||
|
) # argmax returns first True index
|
||||||
|
start_ts = index - (
|
||||||
|
self.cumulative_len[episode_index] - self.episode_len[episode_index]
|
||||||
|
)
|
||||||
|
episode_id = self.episode_ids[episode_index]
|
||||||
|
return episode_id, start_ts
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
episode_id, start_ts = self._locate_transition(index)
|
||||||
|
dataset_path = self.dataset_path_list[episode_id]
|
||||||
|
try:
|
||||||
|
# print(dataset_path)
|
||||||
|
with h5py.File(dataset_path, "r") as root:
|
||||||
|
try: # some legacy data does not have this attribute
|
||||||
|
is_sim = root.attrs["sim"]
|
||||||
|
except:
|
||||||
|
is_sim = False
|
||||||
|
compressed = root.attrs.get("compress", False)
|
||||||
|
if "/base_action" in root:
|
||||||
|
base_action = root["/base_action"][()]
|
||||||
|
base_action = preprocess_base_action(base_action)
|
||||||
|
action = np.concatenate([root["/action"][()], base_action], axis=-1)
|
||||||
|
else:
|
||||||
|
# TODO
|
||||||
|
action = root["/action"][()]
|
||||||
|
# dummy_base_action = np.zeros([action.shape[0], 2])
|
||||||
|
# action = np.concatenate([action, dummy_base_action], axis=-1)
|
||||||
|
original_action_shape = action.shape
|
||||||
|
episode_len = original_action_shape[0]
|
||||||
|
# get observation at start_ts only
|
||||||
|
qpos = root["/observations/qpos"][start_ts]
|
||||||
|
qvel = root["/observations/qvel"][start_ts]
|
||||||
|
image_dict = dict()
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
image_dict[cam_name] = root[f"/observations/images/{cam_name}"][
|
||||||
|
start_ts
|
||||||
|
]
|
||||||
|
|
||||||
|
if compressed:
|
||||||
|
for cam_name in image_dict.keys():
|
||||||
|
decompressed_image = cv2.imdecode(image_dict[cam_name], 1)
|
||||||
|
image_dict[cam_name] = np.array(decompressed_image)
|
||||||
|
|
||||||
|
# get all actions after and including start_ts
|
||||||
|
if is_sim:
|
||||||
|
action = action[start_ts:]
|
||||||
|
action_len = episode_len - start_ts
|
||||||
|
else:
|
||||||
|
action = action[
|
||||||
|
max(0, start_ts - 1) :
|
||||||
|
] # hack, to make timesteps more aligned
|
||||||
|
action_len = episode_len - max(
|
||||||
|
0, start_ts - 1
|
||||||
|
) # hack, to make timesteps more aligned
|
||||||
|
|
||||||
|
# self.is_sim = is_sim
|
||||||
|
padded_action = np.zeros(
|
||||||
|
(self.max_episode_len, original_action_shape[1]), dtype=np.float32
|
||||||
|
)
|
||||||
|
padded_action[:action_len] = action
|
||||||
|
is_pad = np.zeros(self.max_episode_len)
|
||||||
|
is_pad[action_len:] = 1
|
||||||
|
|
||||||
|
padded_action = padded_action[: self.chunk_size]
|
||||||
|
is_pad = is_pad[: self.chunk_size]
|
||||||
|
|
||||||
|
# new axis for different cameras
|
||||||
|
all_cam_images = []
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
all_cam_images.append(image_dict[cam_name])
|
||||||
|
all_cam_images = np.stack(all_cam_images, axis=0)
|
||||||
|
|
||||||
|
# construct observations
|
||||||
|
image_data = torch.from_numpy(all_cam_images)
|
||||||
|
qpos_data = torch.from_numpy(qpos).float()
|
||||||
|
action_data = torch.from_numpy(padded_action).float()
|
||||||
|
is_pad = torch.from_numpy(is_pad).bool()
|
||||||
|
|
||||||
|
# channel last
|
||||||
|
image_data = torch.einsum("k h w c -> k c h w", image_data)
|
||||||
|
|
||||||
|
# augmentation
|
||||||
|
if self.transformations is None:
|
||||||
|
print("Initializing transformations")
|
||||||
|
original_size = image_data.shape[2:]
|
||||||
|
ratio = 0.95
|
||||||
|
self.transformations = [
|
||||||
|
transforms.RandomCrop(
|
||||||
|
size=[
|
||||||
|
int(original_size[0] * ratio),
|
||||||
|
int(original_size[1] * ratio),
|
||||||
|
]
|
||||||
|
),
|
||||||
|
transforms.Resize(original_size, antialias=True),
|
||||||
|
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
|
||||||
|
transforms.ColorJitter(
|
||||||
|
brightness=0.3, contrast=0.4, saturation=0.5
|
||||||
|
), # , hue=0.08)
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.augment_images:
|
||||||
|
for transform in self.transformations:
|
||||||
|
image_data = transform(image_data)
|
||||||
|
|
||||||
|
# normalize image and change dtype to float
|
||||||
|
image_data = image_data / 255.0
|
||||||
|
|
||||||
|
if self.policy_class == "Diffusion":
|
||||||
|
# normalize to [-1, 1]
|
||||||
|
action_data = (
|
||||||
|
(action_data - self.norm_stats["action_min"])
|
||||||
|
/ (self.norm_stats["action_max"] - self.norm_stats["action_min"])
|
||||||
|
) * 2 - 1
|
||||||
|
else:
|
||||||
|
# normalize to mean 0 std 1
|
||||||
|
action_data = (
|
||||||
|
action_data - self.norm_stats["action_mean"]
|
||||||
|
) / self.norm_stats["action_std"]
|
||||||
|
|
||||||
|
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
|
||||||
|
"qpos_std"
|
||||||
|
]
|
||||||
|
|
||||||
|
except:
|
||||||
|
print(f"Error loading {dataset_path} in __getitem__")
|
||||||
|
quit()
|
||||||
|
|
||||||
|
# print(image_data.dtype, qpos_data.dtype, action_data.dtype, is_pad.dtype)
|
||||||
|
return image_data, qpos_data, action_data, is_pad
|
||||||
|
|
||||||
|
|
||||||
|
def get_norm_stats(dataset_path_list):
|
||||||
|
all_qpos_data = []
|
||||||
|
all_action_data = []
|
||||||
|
all_episode_len = []
|
||||||
|
|
||||||
|
for dataset_path in dataset_path_list:
|
||||||
|
try:
|
||||||
|
with h5py.File(dataset_path, "r") as root:
|
||||||
|
qpos = root["/observations/qpos"][()]
|
||||||
|
qvel = root["/observations/qvel"][()]
|
||||||
|
if "/base_action" in root:
|
||||||
|
base_action = root["/base_action"][()]
|
||||||
|
# base_action = preprocess_base_action(base_action)
|
||||||
|
# action = np.concatenate([root["/action"][()], base_action], axis=-1)
|
||||||
|
else:
|
||||||
|
# TODO
|
||||||
|
action = root["/action"][()]
|
||||||
|
# dummy_base_action = np.zeros([action.shape[0], 2])
|
||||||
|
# action = np.concatenate([action, dummy_base_action], axis=-1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error loading {dataset_path} in get_norm_stats")
|
||||||
|
print(e)
|
||||||
|
quit()
|
||||||
|
all_qpos_data.append(torch.from_numpy(qpos))
|
||||||
|
all_action_data.append(torch.from_numpy(action))
|
||||||
|
all_episode_len.append(len(qpos))
|
||||||
|
all_qpos_data = torch.cat(all_qpos_data, dim=0)
|
||||||
|
all_action_data = torch.cat(all_action_data, dim=0)
|
||||||
|
|
||||||
|
# normalize action data
|
||||||
|
action_mean = all_action_data.mean(dim=[0]).float()
|
||||||
|
action_std = all_action_data.std(dim=[0]).float()
|
||||||
|
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
|
||||||
|
|
||||||
|
# normalize qpos data
|
||||||
|
qpos_mean = all_qpos_data.mean(dim=[0]).float()
|
||||||
|
qpos_std = all_qpos_data.std(dim=[0]).float()
|
||||||
|
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
|
||||||
|
|
||||||
|
action_min = all_action_data.min(dim=0).values.float()
|
||||||
|
action_max = all_action_data.max(dim=0).values.float()
|
||||||
|
|
||||||
|
eps = 0.0001
|
||||||
|
stats = {
|
||||||
|
"action_mean": action_mean.numpy(),
|
||||||
|
"action_std": action_std.numpy(),
|
||||||
|
"action_min": action_min.numpy() - eps,
|
||||||
|
"action_max": action_max.numpy() + eps,
|
||||||
|
"qpos_mean": qpos_mean.numpy(),
|
||||||
|
"qpos_std": qpos_std.numpy(),
|
||||||
|
"example_qpos": qpos,
|
||||||
|
}
|
||||||
|
|
||||||
|
return stats, all_episode_len
|
||||||
|
|
||||||
|
|
||||||
|
def find_all_hdf5(dataset_dir, skip_mirrored_data):
|
||||||
|
hdf5_files = []
|
||||||
|
for root, dirs, files in os.walk(dataset_dir):
|
||||||
|
for filename in fnmatch.filter(files, "*.hdf5"):
|
||||||
|
if "features" in filename:
|
||||||
|
continue
|
||||||
|
if skip_mirrored_data and "mirror" in filename:
|
||||||
|
continue
|
||||||
|
hdf5_files.append(os.path.join(root, filename))
|
||||||
|
print(f"Found {len(hdf5_files)} hdf5 files")
|
||||||
|
return hdf5_files
|
||||||
|
|
||||||
|
|
||||||
|
def BatchSampler(batch_size, episode_len_l, sample_weights):
|
||||||
|
sample_probs = (
|
||||||
|
np.array(sample_weights) / np.sum(sample_weights)
|
||||||
|
if sample_weights is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
# print("BatchSampler", sample_weights)
|
||||||
|
sum_dataset_len_l = np.cumsum(
|
||||||
|
[0] + [np.sum(episode_len) for episode_len in episode_len_l]
|
||||||
|
)
|
||||||
|
while True:
|
||||||
|
batch = []
|
||||||
|
for _ in range(batch_size):
|
||||||
|
episode_idx = np.random.choice(len(episode_len_l), p=sample_probs)
|
||||||
|
step_idx = np.random.randint(
|
||||||
|
sum_dataset_len_l[episode_idx], sum_dataset_len_l[episode_idx + 1]
|
||||||
|
)
|
||||||
|
batch.append(step_idx)
|
||||||
|
yield batch
|
||||||
|
|
||||||
|
|
||||||
|
def load_data(
|
||||||
|
dataset_dir_l,
|
||||||
|
name_filter,
|
||||||
|
camera_names,
|
||||||
|
batch_size_train,
|
||||||
|
batch_size_val,
|
||||||
|
chunk_size,
|
||||||
|
skip_mirrored_data=False,
|
||||||
|
load_pretrain=False,
|
||||||
|
policy_class=None,
|
||||||
|
stats_dir_l=None,
|
||||||
|
sample_weights=None,
|
||||||
|
train_ratio=0.99,
|
||||||
|
):
|
||||||
|
if type(dataset_dir_l) == str:
|
||||||
|
dataset_dir_l = [dataset_dir_l]
|
||||||
|
dataset_path_list_list = [
|
||||||
|
find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l
|
||||||
|
]
|
||||||
|
num_episodes_0 = len(dataset_path_list_list[0])
|
||||||
|
dataset_path_list = flatten_list(dataset_path_list_list)
|
||||||
|
|
||||||
|
dataset_path_list = [n for n in dataset_path_list if name_filter(n)]
|
||||||
|
num_episodes_l = [
|
||||||
|
len(dataset_path_list) for dataset_path_list in dataset_path_list_list
|
||||||
|
]
|
||||||
|
num_episodes_cumsum = np.cumsum(num_episodes_l)
|
||||||
|
|
||||||
|
# obtain train test split on dataset_dir_l[0]
|
||||||
|
shuffled_episode_ids_0 = np.random.permutation(num_episodes_0)
|
||||||
|
train_episode_ids_0 = shuffled_episode_ids_0[: int(train_ratio * num_episodes_0)]
|
||||||
|
val_episode_ids_0 = shuffled_episode_ids_0[int(train_ratio * num_episodes_0) :]
|
||||||
|
train_episode_ids_l = [train_episode_ids_0] + [
|
||||||
|
np.arange(num_episodes) + num_episodes_cumsum[idx]
|
||||||
|
for idx, num_episodes in enumerate(num_episodes_l[1:])
|
||||||
|
]
|
||||||
|
val_episode_ids_l = [val_episode_ids_0]
|
||||||
|
train_episode_ids = np.concatenate(train_episode_ids_l)
|
||||||
|
val_episode_ids = np.concatenate(val_episode_ids_l)
|
||||||
|
print(
|
||||||
|
f"\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
# obtain normalization stats for qpos and action
|
||||||
|
# if load_pretrain:
|
||||||
|
# with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f:
|
||||||
|
# norm_stats = pickle.load(f)
|
||||||
|
# print('Loaded pretrain dataset stats')
|
||||||
|
_, all_episode_len = get_norm_stats(dataset_path_list)
|
||||||
|
train_episode_len_l = [
|
||||||
|
[all_episode_len[i] for i in train_episode_ids]
|
||||||
|
for train_episode_ids in train_episode_ids_l
|
||||||
|
]
|
||||||
|
val_episode_len_l = [
|
||||||
|
[all_episode_len[i] for i in val_episode_ids]
|
||||||
|
for val_episode_ids in val_episode_ids_l
|
||||||
|
]
|
||||||
|
|
||||||
|
train_episode_len = flatten_list(train_episode_len_l)
|
||||||
|
val_episode_len = flatten_list(val_episode_len_l)
|
||||||
|
if stats_dir_l is None:
|
||||||
|
stats_dir_l = dataset_dir_l
|
||||||
|
elif type(stats_dir_l) == str:
|
||||||
|
stats_dir_l = [stats_dir_l]
|
||||||
|
norm_stats, _ = get_norm_stats(
|
||||||
|
flatten_list(
|
||||||
|
[find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(f"Norm stats from: {stats_dir_l}")
|
||||||
|
|
||||||
|
batch_sampler_train = BatchSampler(
|
||||||
|
batch_size_train, train_episode_len_l, sample_weights
|
||||||
|
)
|
||||||
|
batch_sampler_val = BatchSampler(batch_size_val, val_episode_len_l, None)
|
||||||
|
|
||||||
|
# print(f'train_episode_len: {train_episode_len}, val_episode_len: {val_episode_len}, train_episode_ids: {train_episode_ids}, val_episode_ids: {val_episode_ids}')
|
||||||
|
|
||||||
|
# construct dataset and dataloader
|
||||||
|
train_dataset = EpisodicDataset(
|
||||||
|
dataset_path_list,
|
||||||
|
camera_names,
|
||||||
|
norm_stats,
|
||||||
|
train_episode_ids,
|
||||||
|
train_episode_len,
|
||||||
|
chunk_size,
|
||||||
|
policy_class,
|
||||||
|
)
|
||||||
|
val_dataset = EpisodicDataset(
|
||||||
|
dataset_path_list,
|
||||||
|
camera_names,
|
||||||
|
norm_stats,
|
||||||
|
val_episode_ids,
|
||||||
|
val_episode_len,
|
||||||
|
chunk_size,
|
||||||
|
policy_class,
|
||||||
|
)
|
||||||
|
train_num_workers = (
|
||||||
|
(8 if os.getlogin() == "zfu" else 16) if train_dataset.augment_images else 2
|
||||||
|
)
|
||||||
|
val_num_workers = 8 if train_dataset.augment_images else 2
|
||||||
|
print(
|
||||||
|
f"Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}"
|
||||||
|
)
|
||||||
|
train_dataloader = DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_sampler=batch_sampler_train,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=train_num_workers,
|
||||||
|
prefetch_factor=2,
|
||||||
|
)
|
||||||
|
val_dataloader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_sampler=batch_sampler_val,
|
||||||
|
pin_memory=True,
|
||||||
|
num_workers=val_num_workers,
|
||||||
|
prefetch_factor=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim
|
||||||
|
|
||||||
|
|
||||||
|
def calibrate_linear_vel(base_action, c=None):
|
||||||
|
if c is None:
|
||||||
|
c = 0.0 # 0.19
|
||||||
|
v = base_action[..., 0]
|
||||||
|
w = base_action[..., 1]
|
||||||
|
base_action = base_action.copy()
|
||||||
|
base_action[..., 0] = v - c * w
|
||||||
|
return base_action
|
||||||
|
|
||||||
|
|
||||||
|
def smooth_base_action(base_action):
|
||||||
|
return np.stack(
|
||||||
|
[
|
||||||
|
np.convolve(base_action[:, i], np.ones(5) / 5, mode="same")
|
||||||
|
for i in range(base_action.shape[1])
|
||||||
|
],
|
||||||
|
axis=-1,
|
||||||
|
).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_base_action(base_action):
|
||||||
|
# base_action = calibrate_linear_vel(base_action)
|
||||||
|
base_action = smooth_base_action(base_action)
|
||||||
|
|
||||||
|
return base_action
|
||||||
|
|
||||||
|
|
||||||
|
def postprocess_base_action(base_action):
|
||||||
|
linear_vel, angular_vel = base_action
|
||||||
|
linear_vel *= 1.0
|
||||||
|
angular_vel *= 1.0
|
||||||
|
# angular_vel = 0
|
||||||
|
# if np.abs(linear_vel) < 0.05:
|
||||||
|
# linear_vel = 0
|
||||||
|
return np.array([linear_vel, angular_vel])
|
||||||
|
|
||||||
|
|
||||||
|
### env utils
|
||||||
|
|
||||||
|
|
||||||
|
def sample_box_pose():
|
||||||
|
x_range = [0.0, 0.2]
|
||||||
|
y_range = [0.4, 0.6]
|
||||||
|
z_range = [0.05, 0.05]
|
||||||
|
|
||||||
|
ranges = np.vstack([x_range, y_range, z_range])
|
||||||
|
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||||
|
|
||||||
|
cube_quat = np.array([1, 0, 0, 0])
|
||||||
|
return np.concatenate([cube_position, cube_quat])
|
||||||
|
|
||||||
|
|
||||||
|
def sample_insertion_pose():
|
||||||
|
# Peg
|
||||||
|
x_range = [0.1, 0.2]
|
||||||
|
y_range = [0.4, 0.6]
|
||||||
|
z_range = [0.05, 0.05]
|
||||||
|
|
||||||
|
ranges = np.vstack([x_range, y_range, z_range])
|
||||||
|
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||||
|
|
||||||
|
peg_quat = np.array([1, 0, 0, 0])
|
||||||
|
peg_pose = np.concatenate([peg_position, peg_quat])
|
||||||
|
|
||||||
|
# Socket
|
||||||
|
x_range = [-0.2, -0.1]
|
||||||
|
y_range = [0.4, 0.6]
|
||||||
|
z_range = [0.05, 0.05]
|
||||||
|
|
||||||
|
ranges = np.vstack([x_range, y_range, z_range])
|
||||||
|
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
|
||||||
|
|
||||||
|
socket_quat = np.array([1, 0, 0, 0])
|
||||||
|
socket_pose = np.concatenate([socket_position, socket_quat])
|
||||||
|
|
||||||
|
return peg_pose, socket_pose
|
||||||
|
|
||||||
|
|
||||||
|
### helper functions
|
||||||
|
|
||||||
|
|
||||||
|
def compute_dict_mean(epoch_dicts):
|
||||||
|
result = {k: None for k in epoch_dicts[0]}
|
||||||
|
num_items = len(epoch_dicts)
|
||||||
|
for k in result:
|
||||||
|
value_sum = 0
|
||||||
|
for epoch_dict in epoch_dicts:
|
||||||
|
value_sum += epoch_dict[k]
|
||||||
|
result[k] = value_sum / num_items
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def detach_dict(d):
|
||||||
|
new_d = dict()
|
||||||
|
for k, v in d.items():
|
||||||
|
new_d[k] = v.detach()
|
||||||
|
return new_d
|
||||||
|
|
||||||
|
|
||||||
|
def set_seed(seed):
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
163
realman_src/realman_aloha/shadow_rm_act/test/test_camera.py
Normal file
163
realman_src/realman_aloha/shadow_rm_act/test/test_camera.py
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
from shadow_camera.realsense import RealSenseCamera
|
||||||
|
from shadow_rm_robot.realman_arm import RmArm
|
||||||
|
import yaml
|
||||||
|
import time
|
||||||
|
import multiprocessing
|
||||||
|
import numpy as np
|
||||||
|
import collections
|
||||||
|
import logging
|
||||||
|
import dm_env
|
||||||
|
import tracemalloc
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceAloha:
|
||||||
|
def __init__(self, aloha_config):
|
||||||
|
"""
|
||||||
|
初始化设备
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device_name (str): 设备名称
|
||||||
|
"""
|
||||||
|
config_left_arm = aloha_config["rm_left_arm"]
|
||||||
|
config_right_arm = aloha_config["rm_right_arm"]
|
||||||
|
config_head_camera = aloha_config["head_camera"]
|
||||||
|
config_bottom_camera = aloha_config["bottom_camera"]
|
||||||
|
config_left_camera = aloha_config["left_camera"]
|
||||||
|
config_right_camera = aloha_config["right_camera"]
|
||||||
|
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
|
||||||
|
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
|
||||||
|
self.arm_left = RmArm(config_left_arm)
|
||||||
|
self.arm_right = RmArm(config_right_arm)
|
||||||
|
self.camera_left = RealSenseCamera(config_head_camera, False)
|
||||||
|
self.camera_right = RealSenseCamera(config_bottom_camera, False)
|
||||||
|
self.camera_bottom = RealSenseCamera(config_left_camera, False)
|
||||||
|
self.camera_top = RealSenseCamera(config_right_camera, False)
|
||||||
|
self.camera_left.start_camera()
|
||||||
|
self.camera_right.start_camera()
|
||||||
|
self.camera_bottom.start_camera()
|
||||||
|
self.camera_top.start_camera()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""
|
||||||
|
关闭摄像头
|
||||||
|
"""
|
||||||
|
self.camera_left.close()
|
||||||
|
self.camera_right.close()
|
||||||
|
self.camera_bottom.close()
|
||||||
|
self.camera_top.close()
|
||||||
|
|
||||||
|
def get_qps(self):
|
||||||
|
"""
|
||||||
|
获取关节角度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 关节角度
|
||||||
|
"""
|
||||||
|
left_slave_arm_angle = self.arm_left.get_joint_angle()
|
||||||
|
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
|
||||||
|
right_slave_arm_angle = self.arm_right.get_joint_angle()
|
||||||
|
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
|
||||||
|
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
|
||||||
|
|
||||||
|
def get_qvel(self):
|
||||||
|
"""
|
||||||
|
获取关节速度
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 关节速度
|
||||||
|
"""
|
||||||
|
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
|
||||||
|
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
|
||||||
|
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
|
||||||
|
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
|
||||||
|
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
|
||||||
|
|
||||||
|
def get_effort(self):
|
||||||
|
"""
|
||||||
|
获取关节力
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.array: 关节力
|
||||||
|
"""
|
||||||
|
left_slave_arm_effort = self.arm_left.get_joint_effort()
|
||||||
|
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
|
||||||
|
right_slave_arm_effort = self.arm_right.get_joint_effort()
|
||||||
|
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
|
||||||
|
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
|
||||||
|
|
||||||
|
def get_images(self):
|
||||||
|
"""
|
||||||
|
获取图像
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 图像字典
|
||||||
|
"""
|
||||||
|
top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
|
||||||
|
bottom_image, _, _, _ = self.camera_bottom.read_frame(True, False, False, False)
|
||||||
|
left_image, _, _, _ = self.camera_left.read_frame(True, False, False, False)
|
||||||
|
right_image, _, _, _ = self.camera_right.read_frame(True, False, False, False)
|
||||||
|
return {
|
||||||
|
"cam_high": top_image,
|
||||||
|
"cam_low": bottom_image,
|
||||||
|
"cam_left": left_image,
|
||||||
|
"cam_right": right_image,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_observation(self):
|
||||||
|
obs = collections.OrderedDict()
|
||||||
|
obs["qpos"] = self.get_qps()
|
||||||
|
obs["qvel"] = self.get_qvel()
|
||||||
|
obs["effort"] = self.get_effort()
|
||||||
|
obs["images"] = self.get_images()
|
||||||
|
# self.get_images()
|
||||||
|
return obs
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
logging.info("Resetting the environment")
|
||||||
|
_ = self.arm_left.set_joint_position(self.init_left_arm_angle[0:6])
|
||||||
|
_ = self.arm_right.set_joint_position(self.init_right_arm_angle[0:6])
|
||||||
|
self.arm_left.set_gripper_position(0)
|
||||||
|
self.arm_right.set_gripper_position(0)
|
||||||
|
return dm_env.TimeStep(
|
||||||
|
step_type=dm_env.StepType.FIRST,
|
||||||
|
reward=0,
|
||||||
|
discount=None,
|
||||||
|
observation=self.get_observation(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def step(self, target_angle):
|
||||||
|
self.arm_left.set_joint_canfd_position(target_angle[0:6])
|
||||||
|
self.arm_right.set_joint_canfd_position(target_angle[7:13])
|
||||||
|
self.arm_left.set_gripper_position(target_angle[6])
|
||||||
|
self.arm_right.set_gripper_position(target_angle[13])
|
||||||
|
return dm_env.TimeStep(
|
||||||
|
step_type=dm_env.StepType.MID,
|
||||||
|
reward=0,
|
||||||
|
discount=None,
|
||||||
|
observation=self.get_observation(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
aloha_config = config["robot_env"]
|
||||||
|
device = DeviceAloha(aloha_config)
|
||||||
|
device.reset()
|
||||||
|
image_list = []
|
||||||
|
tager_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
|
||||||
|
while True:
|
||||||
|
tracemalloc.start() # 启动内存跟踪
|
||||||
|
|
||||||
|
tager_angle = np.array([angle + 0.1 if i not in [6, 13] else angle for i, angle in enumerate(tager_angle)])
|
||||||
|
time_step = time.time()
|
||||||
|
timestep = device.step(tager_angle)
|
||||||
|
logging.info(f"Time: {time.time() - time_step}")
|
||||||
|
image_list.append(timestep.observation["images"])
|
||||||
|
snapshot = tracemalloc.take_snapshot()
|
||||||
|
top_stats = snapshot.statistics('lineno')
|
||||||
|
# del timestep
|
||||||
|
print("[ Top 10 ]")
|
||||||
|
for stat in top_stats[:10]:
|
||||||
|
print(stat)
|
||||||
|
# logging.info(f"Images: {obs}")
|
||||||
32
realman_src/realman_aloha/shadow_rm_act/test/test_h5.py
Normal file
32
realman_src/realman_aloha/shadow_rm_act/test/test_h5.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import os
|
||||||
|
# import time
|
||||||
|
import yaml
|
||||||
|
import torch
|
||||||
|
import pickle
|
||||||
|
import dm_env
|
||||||
|
import logging
|
||||||
|
import collections
|
||||||
|
import numpy as np
|
||||||
|
from einops import rearrange
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torchvision import transforms
|
||||||
|
from shadow_rm_robot.realman_arm import RmArm
|
||||||
|
from shadow_camera.realsense import RealSenseCamera
|
||||||
|
from shadow_act.models.latent_model import Latent_Model_Transformer
|
||||||
|
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
|
||||||
|
from shadow_act.utils.utils import (
|
||||||
|
load_data,
|
||||||
|
sample_box_pose,
|
||||||
|
sample_insertion_pose,
|
||||||
|
compute_dict_mean,
|
||||||
|
set_seed,
|
||||||
|
detach_dict,
|
||||||
|
)
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
print('daasdas')
|
||||||
147
realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py
Normal file
147
realman_src/realman_aloha/shadow_rm_act/visualize_episodes.py
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import h5py
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from constants import DT
|
||||||
|
|
||||||
|
import IPython
|
||||||
|
e = IPython.embed
|
||||||
|
|
||||||
|
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||||
|
STATE_NAMES = JOINT_NAMES + ["gripper"]
|
||||||
|
|
||||||
|
def load_hdf5(dataset_dir, dataset_name):
|
||||||
|
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
|
||||||
|
if not os.path.isfile(dataset_path):
|
||||||
|
print(f'Dataset does not exist at \n{dataset_path}\n')
|
||||||
|
exit()
|
||||||
|
|
||||||
|
with h5py.File(dataset_path, 'r') as root:
|
||||||
|
is_sim = root.attrs['sim']
|
||||||
|
qpos = root['/observations/qpos'][()]
|
||||||
|
qvel = root['/observations/qvel'][()]
|
||||||
|
action = root['/action'][()]
|
||||||
|
image_dict = dict()
|
||||||
|
for cam_name in root[f'/observations/images/'].keys():
|
||||||
|
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
|
||||||
|
|
||||||
|
return qpos, qvel, action, image_dict
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
dataset_dir = args['dataset_dir']
|
||||||
|
episode_idx = args['episode_idx']
|
||||||
|
dataset_name = f'episode_{episode_idx}'
|
||||||
|
|
||||||
|
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
|
||||||
|
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
|
||||||
|
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
|
||||||
|
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back
|
||||||
|
|
||||||
|
|
||||||
|
def save_videos(video, dt, video_path=None):
|
||||||
|
if isinstance(video, list):
|
||||||
|
cam_names = list(video[0].keys())
|
||||||
|
h, w, _ = video[0][cam_names[0]].shape
|
||||||
|
w = w * len(cam_names)
|
||||||
|
fps = int(1/dt)
|
||||||
|
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
|
for ts, image_dict in enumerate(video):
|
||||||
|
images = []
|
||||||
|
for cam_name in cam_names:
|
||||||
|
image = image_dict[cam_name]
|
||||||
|
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||||
|
images.append(image)
|
||||||
|
images = np.concatenate(images, axis=1)
|
||||||
|
out.write(images)
|
||||||
|
out.release()
|
||||||
|
print(f'Saved video to: {video_path}')
|
||||||
|
elif isinstance(video, dict):
|
||||||
|
cam_names = list(video.keys())
|
||||||
|
all_cam_videos = []
|
||||||
|
for cam_name in cam_names:
|
||||||
|
all_cam_videos.append(video[cam_name])
|
||||||
|
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
|
||||||
|
|
||||||
|
n_frames, h, w, _ = all_cam_videos.shape
|
||||||
|
fps = int(1 / dt)
|
||||||
|
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
|
for t in range(n_frames):
|
||||||
|
image = all_cam_videos[t]
|
||||||
|
image = image[:, :, [2, 1, 0]] # swap B and R channel
|
||||||
|
out.write(image)
|
||||||
|
out.release()
|
||||||
|
print(f'Saved video to: {video_path}')
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
|
||||||
|
if label_overwrite:
|
||||||
|
label1, label2 = label_overwrite
|
||||||
|
else:
|
||||||
|
label1, label2 = 'State', 'Command'
|
||||||
|
|
||||||
|
qpos = np.array(qpos_list) # ts, dim
|
||||||
|
command = np.array(command_list)
|
||||||
|
num_ts, num_dim = qpos.shape
|
||||||
|
h, w = 2, num_dim
|
||||||
|
num_figs = num_dim
|
||||||
|
fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))
|
||||||
|
|
||||||
|
# plot joint state
|
||||||
|
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
|
||||||
|
for dim_idx in range(num_dim):
|
||||||
|
ax = axs[dim_idx]
|
||||||
|
ax.plot(qpos[:, dim_idx], label=label1)
|
||||||
|
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
# plot arm command
|
||||||
|
for dim_idx in range(num_dim):
|
||||||
|
ax = axs[dim_idx]
|
||||||
|
ax.plot(command[:, dim_idx], label=label2)
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
|
if ylim:
|
||||||
|
for dim_idx in range(num_dim):
|
||||||
|
ax = axs[dim_idx]
|
||||||
|
ax.set_ylim(ylim)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(plot_path)
|
||||||
|
print(f'Saved qpos plot to: {plot_path}')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def visualize_timestamp(t_list, dataset_path):
|
||||||
|
plot_path = dataset_path.replace('.pkl', '_timestamp.png')
|
||||||
|
h, w = 4, 10
|
||||||
|
fig, axs = plt.subplots(2, 1, figsize=(w, h*2))
|
||||||
|
# process t_list
|
||||||
|
t_float = []
|
||||||
|
for secs, nsecs in t_list:
|
||||||
|
t_float.append(secs + nsecs * 10E-10)
|
||||||
|
t_float = np.array(t_float)
|
||||||
|
|
||||||
|
ax = axs[0]
|
||||||
|
ax.plot(np.arange(len(t_float)), t_float)
|
||||||
|
ax.set_title(f'Camera frame timestamps')
|
||||||
|
ax.set_xlabel('timestep')
|
||||||
|
ax.set_ylabel('time (sec)')
|
||||||
|
|
||||||
|
ax = axs[1]
|
||||||
|
ax.plot(np.arange(len(t_float)-1), t_float[:-1] - t_float[1:])
|
||||||
|
ax.set_title(f'dt')
|
||||||
|
ax.set_xlabel('timestep')
|
||||||
|
ax.set_ylabel('time (sec)')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(plot_path)
|
||||||
|
print(f'Saved timestamp plot to: {plot_path}')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
|
||||||
|
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
|
||||||
|
main(vars(parser.parse_args()))
|
||||||
10
realman_src/realman_aloha/shadow_rm_aloha/.gitignore
vendored
Normal file
10
realman_src/realman_aloha/shadow_rm_aloha/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
__pycache__/
|
||||||
|
build/
|
||||||
|
devel/
|
||||||
|
dist/
|
||||||
|
data/
|
||||||
|
.catkin_workspace
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pt
|
||||||
|
.vscode/
|
||||||
3
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore
generated
vendored
Normal file
3
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
1
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name
generated
Normal file
1
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name
generated
Normal file
@@ -0,0 +1 @@
|
|||||||
|
aloha_data_synchronizer.py
|
||||||
17
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
17
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="4">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="thop" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="torch" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="torchvision" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
6
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml
generated
Normal file
7
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.11 (随箱软件)" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (随箱软件)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml
generated
Normal file
8
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" filepath="$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
12
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml
generated
Normal file
12
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml
generated
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="inheritedJdk" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="PLAIN" />
|
||||||
|
<option name="myDocStringFormat" value="Plain" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
0
realman_src/realman_aloha/shadow_rm_aloha/README.md
Normal file
0
realman_src/realman_aloha/shadow_rm_aloha/README.md
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset'
|
||||||
|
dataset_name: 'episode'
|
||||||
|
max_timesteps: 500
|
||||||
|
state_dim: 14
|
||||||
|
overwrite: False
|
||||||
|
arm_axis: 6
|
||||||
|
camera_names:
|
||||||
|
- 'cam_high'
|
||||||
|
- 'cam_low'
|
||||||
|
- 'cam_left'
|
||||||
|
- 'cam_right'
|
||||||
|
ros_topics:
|
||||||
|
camera_left: '/camera_left/rgb/image_raw'
|
||||||
|
camera_right: '/camera_right/rgb/image_raw'
|
||||||
|
camera_bottom: '/camera_bottom/rgb/image_raw'
|
||||||
|
camera_head: '/camera_head/rgb/image_raw'
|
||||||
|
left_master_arm: '/left_master_arm_joint_states'
|
||||||
|
left_slave_arm: '/left_slave_arm_joint_states'
|
||||||
|
right_master_arm: '/right_master_arm_joint_states'
|
||||||
|
right_slave_arm: '/right_slave_arm_joint_states'
|
||||||
|
left_aloha_state: '/left_slave_arm_aloha_state'
|
||||||
|
right_aloha_state: '/right_slave_arm_aloha_state'
|
||||||
|
robot_env: {
|
||||||
|
# TODO change the path to the correct one
|
||||||
|
rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml',
|
||||||
|
rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml',
|
||||||
|
arm_axis: 6,
|
||||||
|
head_camera: '241122071186',
|
||||||
|
bottom_camera: '152122078546',
|
||||||
|
left_camera: '150622070125',
|
||||||
|
right_camera: '151222072576',
|
||||||
|
init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0],
|
||||||
|
init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0]
|
||||||
|
# init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0],
|
||||||
|
# init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0]
|
||||||
|
# init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
|
||||||
|
# init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
|
||||||
|
}
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
arm_ip: "192.168.1.18"
|
||||||
|
arm_port: 8080
|
||||||
|
arm_axis: 6
|
||||||
|
local_ip: "192.168.1.101"
|
||||||
|
local_port: 8089
|
||||||
|
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||||
|
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||||
|
get_vel: True
|
||||||
|
get_torque: True
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
arm_ip: "192.168.1.19"
|
||||||
|
arm_port: 8080
|
||||||
|
arm_axis: 6
|
||||||
|
local_ip: "192.168.1.101"
|
||||||
|
local_port: 8090
|
||||||
|
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||||
|
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||||
|
get_vel: True
|
||||||
|
get_torque: True
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
port: /dev/ttyUSB1
|
||||||
|
baudrate: 460800
|
||||||
|
hex_data: "55 AA 02 00 00 67"
|
||||||
|
arm_axis: 6
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
port: /dev/ttyUSB0
|
||||||
|
baudrate: 460800
|
||||||
|
hex_data: "55 AA 02 00 00 67"
|
||||||
|
arm_axis: 6
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
|
||||||
|
dataset_name: 'episode'
|
||||||
|
episode_idx: 1
|
||||||
|
FPS: 30
|
||||||
|
# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"] # 7 joints
|
||||||
|
joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] # 6 joints
|
||||||
39
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml
Normal file
39
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
[tool.poetry]
|
||||||
|
name = "shadow_rm_aloha"
|
||||||
|
version = "0.1.1"
|
||||||
|
description = "aloha package, use D435 and Realman robot arm to build aloha to collect data"
|
||||||
|
readme = "README.md"
|
||||||
|
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||||
|
license = "MIT"
|
||||||
|
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||||
|
classifiers = [
|
||||||
|
"Operating System :: POSIX :: Linux amd64",
|
||||||
|
"Programming Language :: Python :: 3.10",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.poetry.dependencies]
|
||||||
|
python = ">=3.10"
|
||||||
|
matplotlib = ">=3.9.2"
|
||||||
|
h5py = ">=3.12.1"
|
||||||
|
# rospy = ">=1.17.0"
|
||||||
|
# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" }
|
||||||
|
# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" }
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||||
|
pytest = ">=8.3"
|
||||||
|
black = ">=24.10.0"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["poetry-core>=1.8.4"]
|
||||||
|
build-backend = "poetry.core.masonry.api"
|
||||||
@@ -0,0 +1,42 @@
|
|||||||
|
cmake_minimum_required(VERSION 3.0.2)
|
||||||
|
project(shadow_rm_aloha)
|
||||||
|
|
||||||
|
find_package(catkin REQUIRED COMPONENTS
|
||||||
|
rospy
|
||||||
|
sensor_msgs
|
||||||
|
cv_bridge
|
||||||
|
image_transport
|
||||||
|
std_msgs
|
||||||
|
message_generation
|
||||||
|
)
|
||||||
|
|
||||||
|
add_service_files(
|
||||||
|
FILES
|
||||||
|
GetArmStatus.srv
|
||||||
|
GetImage.srv
|
||||||
|
MoveArm.srv
|
||||||
|
)
|
||||||
|
|
||||||
|
generate_messages(
|
||||||
|
DEPENDENCIES
|
||||||
|
sensor_msgs
|
||||||
|
std_msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
catkin_package(
|
||||||
|
CATKIN_DEPENDS message_runtime rospy std_msgs
|
||||||
|
)
|
||||||
|
|
||||||
|
include_directories(
|
||||||
|
${catkin_INCLUDE_DIRS}
|
||||||
|
)
|
||||||
|
|
||||||
|
install(PROGRAMS
|
||||||
|
arm_node/slave_arm_publisher.py
|
||||||
|
arm_node/master_arm_publisher.py
|
||||||
|
arm_node/slave_arm_pub_sub.py
|
||||||
|
camera_node/camera_publisher.py
|
||||||
|
data_sub_process/aloha_data_synchronizer.py
|
||||||
|
data_sub_process/aloha_data_collect.py
|
||||||
|
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
|
||||||
|
)
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
import logging
|
||||||
|
from shadow_rm_robot.servo_robotic_arm import ServoArm
|
||||||
|
from sensor_msgs.msg import JointState
|
||||||
|
|
||||||
|
# 配置日志记录
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
class MasterArmPublisher:
|
||||||
|
def __init__(self):
|
||||||
|
rospy.init_node("master_arm_publisher", anonymous=True)
|
||||||
|
arm_config = rospy.get_param("~arm_config","config/servo_left_arm.yaml")
|
||||||
|
hz = rospy.get_param("~hz", 250)
|
||||||
|
self.joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||||
|
|
||||||
|
self.arm = ServoArm(arm_config)
|
||||||
|
self.publisher = rospy.Publisher(self.joint_states_topic, JointState, queue_size=1)
|
||||||
|
self.rate = rospy.Rate(hz) # 30 Hz
|
||||||
|
|
||||||
|
def publish_joint_states(self):
|
||||||
|
while not rospy.is_shutdown():
|
||||||
|
joint_state = JointState()
|
||||||
|
joint_pos = self.arm.get_joint_actions()
|
||||||
|
|
||||||
|
joint_state.header.stamp = rospy.Time.now()
|
||||||
|
joint_state.name = list(joint_pos.keys())
|
||||||
|
joint_state.position = list(joint_pos.values())
|
||||||
|
joint_state.velocity = [0.0] * len(joint_pos) # 速度(可根据需要修改)
|
||||||
|
joint_state.effort = [0.0] * len(joint_pos) # 力矩(可根据需要修改)
|
||||||
|
|
||||||
|
# rospy.loginfo(f"{self.joint_states_topic}: {joint_state}")
|
||||||
|
self.publisher.publish(joint_state)
|
||||||
|
self.rate.sleep()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
arm_publisher = MasterArmPublisher()
|
||||||
|
arm_publisher.publish_joint_states()
|
||||||
|
except rospy.ROSInterruptException:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
import logging
|
||||||
|
from shadow_rm_robot.realman_arm import RmArm
|
||||||
|
from sensor_msgs.msg import JointState
|
||||||
|
# 配置日志记录
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
class SlaveArmPublisher:
|
||||||
|
def __init__(self):
|
||||||
|
rospy.init_node("slave_arm_publisher", anonymous=True)
|
||||||
|
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||||
|
hz = rospy.get_param("~hz", 250)
|
||||||
|
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||||
|
joint_actions_topic = rospy.get_param("~joint_actions_topic", "/joint_actions")
|
||||||
|
self.arm = RmArm(arm_config)
|
||||||
|
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
|
||||||
|
self.subscriber = rospy.Subscriber(joint_actions_topic, JointState, self.callback)
|
||||||
|
self.rate = rospy.Rate(hz)
|
||||||
|
|
||||||
|
def publish_joint_states(self):
|
||||||
|
while not rospy.is_shutdown():
|
||||||
|
joint_state = JointState()
|
||||||
|
data = self.arm.get_integrate_data()
|
||||||
|
# data = self.arm.get_arm_data()
|
||||||
|
joint_state.header.stamp = rospy.Time.now()
|
||||||
|
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
|
||||||
|
# joint_state.position = data["joint_angle"]
|
||||||
|
joint_state.position = data['arm_angle']
|
||||||
|
|
||||||
|
# joint_state.position = list(data["arm_angle"])
|
||||||
|
# joint_state.velocity = list(data["arm_velocity"])
|
||||||
|
# joint_state.effort = list(data["arm_torque"])
|
||||||
|
# rospy.loginfo(f"joint_states_topic: {joint_state}")
|
||||||
|
self.publisher.publish(joint_state)
|
||||||
|
self.rate.sleep()
|
||||||
|
|
||||||
|
def callback(self, data):
|
||||||
|
# rospy.loginfo(f"Received joint_states_topic: {data}")
|
||||||
|
start_time = rospy.Time.now()
|
||||||
|
if data is None:
|
||||||
|
return
|
||||||
|
if data.name == ["joint_canfd"]:
|
||||||
|
self.arm.set_joint_canfd_position(data.position[0:6])
|
||||||
|
elif data.name == ["joint_j"]:
|
||||||
|
self.arm.set_joint_position(data.position[0:6])
|
||||||
|
|
||||||
|
# self.arm.set_gripper_position(data.position[6])
|
||||||
|
end_time = rospy.Time.now()
|
||||||
|
time_cost_ms = (end_time - start_time).to_sec() * 1000
|
||||||
|
rospy.loginfo(f"Time cost: {data.name},{time_cost_ms}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
arm_publisher = SlaveArmPublisher()
|
||||||
|
arm_publisher.publish_joint_states()
|
||||||
|
except rospy.ROSInterruptException:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
import logging
|
||||||
|
from shadow_rm_robot.realman_arm import RmArm
|
||||||
|
from sensor_msgs.msg import JointState
|
||||||
|
from std_msgs.msg import Int32MultiArray
|
||||||
|
# 配置日志记录
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
class SlaveArmPublisher:
|
||||||
|
def __init__(self):
|
||||||
|
rospy.init_node("slave_arm_publisher", anonymous=True)
|
||||||
|
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||||
|
hz = rospy.get_param("~hz", 250)
|
||||||
|
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||||
|
aloha_state_topic = rospy.get_param("~aloha_state_topic", "/aloha_state")
|
||||||
|
self.arm = RmArm(arm_config)
|
||||||
|
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
|
||||||
|
self.aloha_state_pub = rospy.Publisher(aloha_state_topic, Int32MultiArray, queue_size=1)
|
||||||
|
self.rate = rospy.Rate(hz)
|
||||||
|
|
||||||
|
def publish_joint_states(self):
|
||||||
|
while not rospy.is_shutdown():
|
||||||
|
joint_state = JointState()
|
||||||
|
data = self.arm.get_integrate_data()
|
||||||
|
joint_state.header.stamp = rospy.Time.now()
|
||||||
|
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"]
|
||||||
|
joint_state.position = list(data["arm_angle"])
|
||||||
|
joint_state.velocity = list(data["arm_velocity"])
|
||||||
|
joint_state.effort = list(data["arm_torque"])
|
||||||
|
# rospy.loginfo(f"joint_states_topic: {joint_state}")
|
||||||
|
self.aloha_state_pub.publish(Int32MultiArray(data=data["aloha_state"].values()))
|
||||||
|
self.publisher.publish(joint_state)
|
||||||
|
self.rate.sleep()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
arm_publisher = SlaveArmPublisher()
|
||||||
|
arm_publisher.publish_joint_states()
|
||||||
|
except rospy.ROSInterruptException:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from std_msgs.msg import Header
|
||||||
|
import numpy as np
|
||||||
|
from shadow_camera.realsense import RealSenseCamera
|
||||||
|
|
||||||
|
class CameraPublisher:
|
||||||
|
def __init__(self):
|
||||||
|
rospy.init_node('camera_publisher', anonymous=True)
|
||||||
|
self.serial_number = rospy.get_param('~serial_number', None)
|
||||||
|
hz = rospy.get_param('~hz', 30)
|
||||||
|
rospy.loginfo(f"Serial number: {self.serial_number}")
|
||||||
|
|
||||||
|
self.rgb_topic = rospy.get_param('~rgb_topic', '/camera/rgb/image_raw')
|
||||||
|
self.depth_topic = rospy.get_param('~depth_topic', '/camera/depth/image_raw')
|
||||||
|
rospy.loginfo(f"RGB topic: {self.rgb_topic}")
|
||||||
|
rospy.loginfo(f"Depth topic: {self.depth_topic}")
|
||||||
|
|
||||||
|
self.rgb_pub = rospy.Publisher(self.rgb_topic, Image, queue_size=10)
|
||||||
|
# self.depth_pub = rospy.Publisher(self.depth_topic, Image, queue_size=10)
|
||||||
|
self.rate = rospy.Rate(hz) # 30 Hz
|
||||||
|
self.camera = RealSenseCamera(self.serial_number, False)
|
||||||
|
|
||||||
|
rospy.loginfo("Camera initialized")
|
||||||
|
|
||||||
|
def publish_images(self):
|
||||||
|
self.camera.start_camera()
|
||||||
|
rospy.loginfo("Camera started")
|
||||||
|
while not rospy.is_shutdown():
|
||||||
|
result = self.camera.read_frame(True, False, False, False)
|
||||||
|
if result is None:
|
||||||
|
rospy.logerr("Failed to read frame from camera")
|
||||||
|
continue
|
||||||
|
|
||||||
|
color_image, depth_image, _, _ = result
|
||||||
|
|
||||||
|
if color_image is not None or depth_image is not None:
|
||||||
|
rgb_msg = self.create_image_msg(color_image, "bgr8")
|
||||||
|
# depth_msg = self.create_image_msg(depth_image, "mono16")
|
||||||
|
|
||||||
|
self.rgb_pub.publish(rgb_msg)
|
||||||
|
# self.depth_pub.publish(depth_msg)
|
||||||
|
# rospy.loginfo("Published RGB image")
|
||||||
|
else:
|
||||||
|
rospy.logwarn("Received None for color_image or depth_image")
|
||||||
|
|
||||||
|
self.rate.sleep()
|
||||||
|
self.camera.stop_camera()
|
||||||
|
rospy.loginfo("Camera stopped")
|
||||||
|
|
||||||
|
def create_image_msg(self, image, encoding):
|
||||||
|
msg = Image()
|
||||||
|
msg.header = Header()
|
||||||
|
msg.header.stamp = rospy.Time.now()
|
||||||
|
msg.height, msg.width = image.shape[:2]
|
||||||
|
msg.encoding = encoding
|
||||||
|
msg.is_bigendian = False
|
||||||
|
msg.step = image.strides[0]
|
||||||
|
msg.data = np.array(image).tobytes()
|
||||||
|
return msg
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
try:
|
||||||
|
camera_publisher = CameraPublisher()
|
||||||
|
camera_publisher.publish_images()
|
||||||
|
except rospy.ROSInterruptException:
|
||||||
|
pass
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import h5py
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
|
class DataCollector:
|
||||||
|
def __init__(self, dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite=False):
|
||||||
|
self.arm_axis = 7
|
||||||
|
self.dataset_dir = dataset_dir
|
||||||
|
self.dataset_name = dataset_name
|
||||||
|
self.max_timesteps = max_timesteps
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.state_dim = state_dim
|
||||||
|
self.overwrite = overwrite
|
||||||
|
self.data_dict = {
|
||||||
|
'/observations/qpos': [],
|
||||||
|
'/observations/qvel': [],
|
||||||
|
'/observations/effort': [],
|
||||||
|
'/action': [],
|
||||||
|
}
|
||||||
|
for cam_name in camera_names:
|
||||||
|
self.data_dict[f'/observations/images/{cam_name}'] = []
|
||||||
|
|
||||||
|
# 自动检测和创建数据集目录
|
||||||
|
self.create_dataset_dir()
|
||||||
|
self.timesteps_collected = 0
|
||||||
|
|
||||||
|
|
||||||
|
def create_dataset_dir(self):
|
||||||
|
# 按照年月日创建目录
|
||||||
|
date_str = datetime.now().strftime("%Y%m%d")
|
||||||
|
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||||
|
if not os.path.exists(self.dataset_dir):
|
||||||
|
os.makedirs(self.dataset_dir)
|
||||||
|
|
||||||
|
def collect_data(self, ts, action):
|
||||||
|
self.data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||||
|
self.data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||||
|
self.data_dict['/observations/effort'].append(ts.observation['effort'])
|
||||||
|
self.data_dict['/action'].append(action)
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
self.data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||||
|
|
||||||
|
def save_data(self):
|
||||||
|
t0 = time.time()
|
||||||
|
# 保存数据
|
||||||
|
with h5py.File(self.dataset_path, mode='w', rdcc_nbytes=1024**2*2) as root:
|
||||||
|
root.attrs['sim'] = False
|
||||||
|
obs = root.create_group('observations')
|
||||||
|
image = obs.create_group('images')
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
_ = image.create_dataset(cam_name, (self.max_timesteps, 480, 640, 3), dtype='uint8',
|
||||||
|
chunks=(1, 480, 640, 3))
|
||||||
|
_ = obs.create_dataset('qpos', (self.max_timesteps, self.state_dim))
|
||||||
|
_ = obs.create_dataset('qvel', (self.max_timesteps, self.state_dim))
|
||||||
|
_ = obs.create_dataset('effort', (self.max_timesteps, self.state_dim))
|
||||||
|
_ = root.create_dataset('action', (self.max_timesteps, self.state_dim))
|
||||||
|
|
||||||
|
for name, array in self.data_dict.items():
|
||||||
|
root[name][...] = array
|
||||||
|
print(f'Saving: {time.time() - t0:.1f} secs')
|
||||||
|
return True
|
||||||
|
|
||||||
|
def load_hdf5(self, orign_path, file):
|
||||||
|
self.dataset_path = os.path.join(self.dataset_dir, file)
|
||||||
|
if not os.path.isfile(orign_path):
|
||||||
|
logging.error(f'Dataset does not exist at {orign_path}')
|
||||||
|
exit()
|
||||||
|
|
||||||
|
with h5py.File(orign_path, 'r') as root:
|
||||||
|
self.is_sim = root.attrs['sim']
|
||||||
|
self.qpos = root['/observations/qpos'][()]
|
||||||
|
self.qvel = root['/observations/qvel'][()]
|
||||||
|
self.effort = root['/observations/effort'][()]
|
||||||
|
self.action = root['/action'][()]
|
||||||
|
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||||
|
for cam_name in root[f'/observations/images/'].keys()}
|
||||||
|
|
||||||
|
self.qpos[:, self.arm_axis] = self.action[:, self.arm_axis]
|
||||||
|
self.qpos[:, self.arm_axis*2+1] = self.action[:, self.arm_axis*2+1]
|
||||||
|
|
||||||
|
self.data_dict['/observations/qpos'] = self.qpos
|
||||||
|
self.data_dict['/observations/qvel'] = self.qvel
|
||||||
|
self.data_dict['/observations/effort'] = self.effort
|
||||||
|
self.data_dict['/action'] = self.action
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
self.data_dict[f'/observations/images/{cam_name}'] = self.image_dict[cam_name]
|
||||||
|
return True
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
"""
|
||||||
|
用于更改夹爪数据,将从臂夹爪数据更改为主笔夹爪数据
|
||||||
|
|
||||||
|
"""
|
||||||
|
dataset_dir = '/home/wang/project/shadow_rm_aloha/data'
|
||||||
|
orign_dir = '/home/wang/project/shadow_rm_aloha/data/dataset/20241128'
|
||||||
|
dataset_name = 'test'
|
||||||
|
max_timesteps = 300
|
||||||
|
camera_names = ['cam_high','cam_low','cam_left','cam_right']
|
||||||
|
state_dim = 16
|
||||||
|
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
|
||||||
|
for file in os.listdir(orign_dir):
|
||||||
|
collector.__init__(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
|
||||||
|
orign_path = os.path.join(orign_dir, file)
|
||||||
|
print(orign_path)
|
||||||
|
collector.load_hdf5(orign_path, file)
|
||||||
|
collector.save_data()
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import h5py
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from shadow_rm_robot.realman_arm import RmArm
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
|
class DataValidator:
|
||||||
|
def __init__(self, config):
|
||||||
|
self.dataset_dir = config['dataset_dir']
|
||||||
|
self.episode_idx = config['episode_idx']
|
||||||
|
self.joint_names = config['joint_names']
|
||||||
|
self.dataset_name = f'episode_{self.episode_idx}'
|
||||||
|
self.dataset_path = os.path.join(self.dataset_dir, self.dataset_name + '.hdf5')
|
||||||
|
self.state_names = self.joint_names + ["gripper"]
|
||||||
|
self.arm = RmArm('/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml')
|
||||||
|
|
||||||
|
def load_hdf5(self):
|
||||||
|
if not os.path.isfile(self.dataset_path):
|
||||||
|
logging.error(f'Dataset does not exist at {self.dataset_path}')
|
||||||
|
exit()
|
||||||
|
|
||||||
|
with h5py.File(self.dataset_path, 'r') as root:
|
||||||
|
self.is_sim = root.attrs['sim']
|
||||||
|
self.qpos = root['/observations/qpos'][()]
|
||||||
|
# self.qvel = root['/observations/qvel'][()]
|
||||||
|
# self.effort = root['/observations/effort'][()]
|
||||||
|
self.action = root['/action'][()]
|
||||||
|
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||||
|
for cam_name in root[f'/observations/images/'].keys()}
|
||||||
|
|
||||||
|
def validate_data(self):
|
||||||
|
# 验证位置数据
|
||||||
|
if not self.qpos.shape[1] == 14:
|
||||||
|
logging.error('qpos shape does not match expected number of joints')
|
||||||
|
return False
|
||||||
|
|
||||||
|
logging.info('Data validation passed')
|
||||||
|
return True
|
||||||
|
|
||||||
|
def control_robot(self):
|
||||||
|
self.arm.set_joint_position(self.qpos[0][0:6])
|
||||||
|
for qpos in self.qpos:
|
||||||
|
logging.info(f'qpos: {qpos}')
|
||||||
|
self.arm.set_joint_canfd_position(qpos[7:13])
|
||||||
|
self.arm.set_gripper_position(qpos[13])
|
||||||
|
time.sleep(0.035)
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
self.load_hdf5()
|
||||||
|
if self.validate_data():
|
||||||
|
self.control_robot()
|
||||||
|
|
||||||
|
def load_config(config_path):
|
||||||
|
with open(config_path, 'r') as file:
|
||||||
|
return yaml.safe_load(file)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
|
||||||
|
validator = DataValidator(config)
|
||||||
|
validator.run()
|
||||||
@@ -0,0 +1,147 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import h5py
|
||||||
|
import yaml
|
||||||
|
import logging
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
|
||||||
|
class DataVisualizer:
|
||||||
|
def __init__(self, config):
|
||||||
|
self.dataset_dir = config['dataset_dir']
|
||||||
|
self.episode_idx = config['episode_idx']
|
||||||
|
self.dt = 1/config['FPS']
|
||||||
|
self.joint_names = config['joint_names']
|
||||||
|
self.state_names = self.joint_names + ["gripper"]
|
||||||
|
# self.camera_names = config['camera_names']
|
||||||
|
|
||||||
|
def join_file_path(self, file_name):
|
||||||
|
self.dataset_path = os.path.join(self.dataset_dir, file_name)
|
||||||
|
|
||||||
|
def load_hdf5(self):
|
||||||
|
if not os.path.isfile(self.dataset_path):
|
||||||
|
logging.error(f'Dataset does not exist at {self.dataset_path}')
|
||||||
|
exit()
|
||||||
|
|
||||||
|
with h5py.File(self.dataset_path, 'r') as root:
|
||||||
|
self.is_sim = root.attrs['sim']
|
||||||
|
self.qpos = root['/observations/qpos'][()]
|
||||||
|
# self.qvel = root['/observations/qvel'][()]
|
||||||
|
# self.effort = root['/observations/effort'][()]
|
||||||
|
self.action = root['/action'][()]
|
||||||
|
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||||
|
for cam_name in root[f'/observations/images/'].keys()}
|
||||||
|
|
||||||
|
def save_videos(self, video, dt, video_path=None):
|
||||||
|
if isinstance(video, list):
|
||||||
|
cam_names = list(video[0].keys())
|
||||||
|
h, w, _ = video[0][cam_names[0]].shape
|
||||||
|
w = w * len(cam_names)
|
||||||
|
fps = int(1 / dt)
|
||||||
|
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
|
for image_dict in video:
|
||||||
|
images = [image_dict[cam_name][:, :, [2, 1, 0]] for cam_name in cam_names]
|
||||||
|
out.write(np.concatenate(images, axis=1))
|
||||||
|
out.release()
|
||||||
|
logging.info(f'Saved video to: {video_path}')
|
||||||
|
elif isinstance(video, dict):
|
||||||
|
cam_names = list(video.keys())
|
||||||
|
all_cam_videos = np.concatenate([video[cam_name] for cam_name in cam_names], axis=2)
|
||||||
|
n_frames, h, w, _ = all_cam_videos.shape
|
||||||
|
fps = int(1 / dt)
|
||||||
|
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||||
|
for t in range(n_frames):
|
||||||
|
out.write(all_cam_videos[t][:, :, [2, 1, 0]])
|
||||||
|
out.release()
|
||||||
|
logging.info(f'Saved video to: {video_path}')
|
||||||
|
|
||||||
|
def visualize_joints(self, qpos_list, command_list, plot_path, ylim=None, label_overwrite=None):
|
||||||
|
label1, label2 = label_overwrite if label_overwrite else ('State', 'Command')
|
||||||
|
qpos = np.array(qpos_list)
|
||||||
|
command = np.array(command_list)
|
||||||
|
num_ts, num_dim = qpos.shape
|
||||||
|
logging.info(f'qpos shape: {qpos.shape}, command shape: {command.shape}')
|
||||||
|
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
|
||||||
|
|
||||||
|
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
|
||||||
|
for dim_idx in range(num_dim):
|
||||||
|
ax = axs[dim_idx]
|
||||||
|
ax.plot(qpos[:, dim_idx], label=label1)
|
||||||
|
ax.plot(command[:, dim_idx], label=label2)
|
||||||
|
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||||
|
ax.legend()
|
||||||
|
if ylim:
|
||||||
|
ax.set_ylim(ylim)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(plot_path)
|
||||||
|
logging.info(f'Saved qpos plot to: {plot_path}')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def visualize_single(self, data_list, label, plot_path, ylim=None):
|
||||||
|
data = np.array(data_list)
|
||||||
|
num_ts, num_dim = data.shape
|
||||||
|
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
|
||||||
|
|
||||||
|
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
|
||||||
|
for dim_idx in range(num_dim):
|
||||||
|
ax = axs[dim_idx]
|
||||||
|
ax.plot(data[:, dim_idx], label=label)
|
||||||
|
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||||
|
ax.legend()
|
||||||
|
if ylim:
|
||||||
|
ax.set_ylim(ylim)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(plot_path)
|
||||||
|
logging.info(f'Saved {label} plot to: {plot_path}')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def visualize_timestamp(self, t_list):
|
||||||
|
plot_path = self.dataset_path.replace('.hdf5', '_timestamp.png')
|
||||||
|
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
|
||||||
|
t_float = np.array([secs + nsecs * 1e-9 for secs, nsecs in t_list])
|
||||||
|
|
||||||
|
axs[0].plot(np.arange(len(t_float)), t_float)
|
||||||
|
axs[0].set_title('Camera frame timestamps')
|
||||||
|
axs[0].set_xlabel('timestep')
|
||||||
|
axs[0].set_ylabel('time (sec)')
|
||||||
|
|
||||||
|
axs[1].plot(np.arange(len(t_float) - 1), t_float[:-1] - t_float[1:])
|
||||||
|
axs[1].set_title('dt')
|
||||||
|
axs[1].set_xlabel('timestep')
|
||||||
|
axs[1].set_ylabel('time (sec)')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig(plot_path)
|
||||||
|
logging.info(f'Saved timestamp plot to: {plot_path}')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
for file_name in os.listdir(self.dataset_dir):
|
||||||
|
if file_name.endswith('.hdf5'):
|
||||||
|
self.join_file_path(file_name)
|
||||||
|
self.load_hdf5()
|
||||||
|
video_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_video.mp4'))
|
||||||
|
self.save_videos(self.image_dict, self.dt, video_path)
|
||||||
|
qpos_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_qpos.png'))
|
||||||
|
self.visualize_joints(self.qpos, self.action, qpos_plot_path)
|
||||||
|
# effort_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_effort.png'))
|
||||||
|
# self.visualize_single(self.effort, 'effort', effort_plot_path)
|
||||||
|
# error_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_error.png'))
|
||||||
|
# self.visualize_single(self.action - self.qpos, 'tracking_error', error_plot_path)
|
||||||
|
# self.visualize_timestamp(t_list) # TODO: Add timestamp visualization back
|
||||||
|
|
||||||
|
|
||||||
|
def load_config(config_path):
|
||||||
|
with open(config_path, 'r') as file:
|
||||||
|
return yaml.safe_load(file)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
|
||||||
|
visualizer = DataVisualizer(config)
|
||||||
|
visualizer.run()
|
||||||
@@ -0,0 +1,180 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
import rospy
|
||||||
|
import dm_env
|
||||||
|
import numpy as np
|
||||||
|
import collections
|
||||||
|
from datetime import datetime
|
||||||
|
from sensor_msgs.msg import Image, JointState
|
||||||
|
from shadow_rm_robot.realman_arm import RmArm
|
||||||
|
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||||
|
|
||||||
|
|
||||||
|
class DataSynchronizer:
|
||||||
|
def __init__(self, config_path="config"):
|
||||||
|
rospy.init_node("synchronizer", anonymous=True)
|
||||||
|
|
||||||
|
with open(config_path, "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
|
||||||
|
self.init_left_arm_angle = config["robot_env"]["init_left_arm_angle"]
|
||||||
|
self.init_right_arm_angle = config["robot_env"]["init_right_arm_angle"]
|
||||||
|
self.arm_axis = config["robot_env"]["arm_axis"]
|
||||||
|
self.camera_names = config["camera_names"]
|
||||||
|
|
||||||
|
# 创建订阅者
|
||||||
|
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||||
|
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||||
|
self.camera_bottom_sub = Subscriber(
|
||||||
|
config["ros_topics"]["camera_bottom"], Image
|
||||||
|
)
|
||||||
|
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||||
|
|
||||||
|
self.left_slave_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["left_slave_arm_sub"], JointState
|
||||||
|
)
|
||||||
|
self.right_slave_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["right_slave_arm_sub"], JointState
|
||||||
|
)
|
||||||
|
self.left_slave_arm_pub = rospy.Publisher(
|
||||||
|
config["ros_topics"]["left_slave_arm_pub"], JointState, queue_size=1
|
||||||
|
)
|
||||||
|
self.right_slave_arm_pub = rospy.Publisher(
|
||||||
|
config["ros_topics"]["right_slave_arm_pub"], JointState, queue_size=1
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建同步器
|
||||||
|
self.ats = ApproximateTimeSynchronizer(
|
||||||
|
[
|
||||||
|
self.camera_left_sub,
|
||||||
|
self.camera_right_sub,
|
||||||
|
self.camera_bottom_sub,
|
||||||
|
self.camera_head_sub,
|
||||||
|
self.left_slave_arm_sub,
|
||||||
|
self.right_slave_arm_sub,
|
||||||
|
],
|
||||||
|
queue_size=1,
|
||||||
|
slop=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ats.registerCallback(self.callback)
|
||||||
|
self.ts = None
|
||||||
|
self.is_frist_step = True
|
||||||
|
|
||||||
|
def callback(
|
||||||
|
self,
|
||||||
|
camera_left_img,
|
||||||
|
camera_right_img,
|
||||||
|
camera_bottom_img,
|
||||||
|
camera_head_img,
|
||||||
|
left_slave_arm,
|
||||||
|
right_slave_arm,
|
||||||
|
):
|
||||||
|
|
||||||
|
# 将ROS图像消息转换为NumPy数组
|
||||||
|
camera_left_np_img = np.frombuffer(
|
||||||
|
camera_left_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||||
|
camera_right_np_img = np.frombuffer(
|
||||||
|
camera_right_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||||
|
camera_bottom_np_img = np.frombuffer(
|
||||||
|
camera_bottom_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||||
|
camera_head_np_img = np.frombuffer(
|
||||||
|
camera_head_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||||
|
|
||||||
|
left_slave_arm_angle = left_slave_arm.position
|
||||||
|
left_slave_arm_velocity = left_slave_arm.velocity
|
||||||
|
left_slave_arm_force = left_slave_arm.effort
|
||||||
|
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||||
|
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||||
|
|
||||||
|
right_slave_arm_angle = right_slave_arm.position
|
||||||
|
right_slave_arm_velocity = right_slave_arm.velocity
|
||||||
|
right_slave_arm_force = right_slave_arm.effort
|
||||||
|
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||||
|
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||||
|
|
||||||
|
# 收集数据
|
||||||
|
obs = collections.OrderedDict(
|
||||||
|
{
|
||||||
|
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||||
|
"qvel": np.concatenate(
|
||||||
|
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||||
|
),
|
||||||
|
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||||
|
"images": {
|
||||||
|
self.camera_names[0]: camera_head_np_img,
|
||||||
|
self.camera_names[1]: camera_bottom_np_img,
|
||||||
|
self.camera_names[2]: camera_left_np_img,
|
||||||
|
self.camera_names[3]: camera_right_np_img,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.ts = dm_env.TimeStep(
|
||||||
|
step_type=(
|
||||||
|
dm_env.StepType.FIRST if self.is_frist_step else dm_env.StepType.MID
|
||||||
|
),
|
||||||
|
reward=0.0,
|
||||||
|
discount=1.0,
|
||||||
|
observation=obs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
|
||||||
|
left_joint_state = JointState()
|
||||||
|
left_joint_state.header.stamp = rospy.Time.now()
|
||||||
|
left_joint_state.name = ["joint_j"]
|
||||||
|
left_joint_state.position = self.init_left_arm_angle[0 : self.arm_axis + 1]
|
||||||
|
|
||||||
|
right_joint_state = JointState()
|
||||||
|
right_joint_state.header.stamp = rospy.Time.now()
|
||||||
|
right_joint_state.name = ["joint_j"]
|
||||||
|
right_joint_state.position = self.init_right_arm_angle[0 : self.arm_axis + 1]
|
||||||
|
|
||||||
|
self.left_slave_arm_pub.publish(left_joint_state)
|
||||||
|
self.right_slave_arm_pub.publish(right_joint_state)
|
||||||
|
while self.ts is None:
|
||||||
|
time.sleep(0.002)
|
||||||
|
|
||||||
|
return self.ts
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def step(self, target_angle):
|
||||||
|
self.is_frist_step = False
|
||||||
|
left_joint_state = JointState()
|
||||||
|
left_joint_state.header.stamp = rospy.Time.now()
|
||||||
|
left_joint_state.name = ["joint_canfd"]
|
||||||
|
left_joint_state.position = target_angle[0 : self.arm_axis + 1]
|
||||||
|
# print("left_joint_state: ", left_joint_state)
|
||||||
|
|
||||||
|
right_joint_state = JointState()
|
||||||
|
right_joint_state.header.stamp = rospy.Time.now()
|
||||||
|
right_joint_state.name = ["joint_canfd"]
|
||||||
|
right_joint_state.position = target_angle[self.arm_axis + 1 : (self.arm_axis + 1) * 2]
|
||||||
|
# print("right_joint_state: ", right_joint_state)
|
||||||
|
|
||||||
|
self.left_slave_arm_pub.publish(left_joint_state)
|
||||||
|
self.right_slave_arm_pub.publish(right_joint_state)
|
||||||
|
# time.sleep(0.013)
|
||||||
|
return self.ts
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
rospy.loginfo("Starting ROS spin")
|
||||||
|
data = np.concatenate([self.init_left_arm_angle, self.init_right_arm_angle])
|
||||||
|
self.reset()
|
||||||
|
# print("data: ", data)
|
||||||
|
while not rospy.is_shutdown():
|
||||||
|
self.step(data)
|
||||||
|
rospy.sleep(0.010)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
|
||||||
|
start_time = time.time()
|
||||||
|
synchronizer.run()
|
||||||
@@ -0,0 +1,280 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import h5py
|
||||||
|
import yaml
|
||||||
|
import rospy
|
||||||
|
import dm_env
|
||||||
|
import numpy as np
|
||||||
|
import collections
|
||||||
|
from datetime import datetime
|
||||||
|
from std_msgs.msg import Int32MultiArray
|
||||||
|
from sensor_msgs.msg import Image, JointState
|
||||||
|
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||||
|
|
||||||
|
|
||||||
|
class DataCollector:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_dir,
|
||||||
|
dataset_name,
|
||||||
|
max_timesteps,
|
||||||
|
camera_names,
|
||||||
|
state_dim,
|
||||||
|
overwrite=False,
|
||||||
|
):
|
||||||
|
self.dataset_dir = dataset_dir
|
||||||
|
self.dataset_name = dataset_name
|
||||||
|
self.max_timesteps = max_timesteps
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.state_dim = state_dim
|
||||||
|
self.overwrite = overwrite
|
||||||
|
self.init_dict()
|
||||||
|
self.create_dataset_dir()
|
||||||
|
|
||||||
|
def init_dict(self):
|
||||||
|
self.data_dict = {
|
||||||
|
"/observations/qpos": [],
|
||||||
|
"/observations/qvel": [],
|
||||||
|
"/observations/effort": [],
|
||||||
|
"/action": [],
|
||||||
|
}
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
self.data_dict[f"/observations/images/{cam_name}"] = []
|
||||||
|
|
||||||
|
def create_dataset_dir(self):
|
||||||
|
# 按照年月日创建目录
|
||||||
|
date_str = datetime.now().strftime("%Y%m%d")
|
||||||
|
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||||
|
if not os.path.exists(self.dataset_dir):
|
||||||
|
os.makedirs(self.dataset_dir)
|
||||||
|
|
||||||
|
def create_file(self):
|
||||||
|
# 检查数据集名称是否存在,如果存在则递增名称
|
||||||
|
counter = 0
|
||||||
|
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
|
||||||
|
if not self.overwrite:
|
||||||
|
while os.path.exists(dataset_path + ".hdf5"):
|
||||||
|
dataset_path = os.path.join(
|
||||||
|
self.dataset_dir, f"{self.dataset_name}_{counter}"
|
||||||
|
)
|
||||||
|
counter += 1
|
||||||
|
self.dataset_path = dataset_path
|
||||||
|
|
||||||
|
def collect_data(self, ts, action):
|
||||||
|
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
|
||||||
|
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
|
||||||
|
self.data_dict["/observations/effort"].append(ts.observation["effort"])
|
||||||
|
self.data_dict["/action"].append(action)
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
self.data_dict[f"/observations/images/{cam_name}"].append(
|
||||||
|
ts.observation["images"][cam_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_data(self):
|
||||||
|
self.create_file()
|
||||||
|
t0 = time.time()
|
||||||
|
# 保存数据
|
||||||
|
with h5py.File(
|
||||||
|
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
|
||||||
|
) as root:
|
||||||
|
root.attrs["sim"] = False
|
||||||
|
obs = root.create_group("observations")
|
||||||
|
image = obs.create_group("images")
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
_ = image.create_dataset(
|
||||||
|
cam_name,
|
||||||
|
(self.max_timesteps, 480, 640, 3),
|
||||||
|
dtype="uint8",
|
||||||
|
chunks=(1, 480, 640, 3),
|
||||||
|
)
|
||||||
|
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
|
||||||
|
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
|
||||||
|
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
|
||||||
|
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
|
||||||
|
|
||||||
|
for name, array in self.data_dict.items():
|
||||||
|
root[name][...] = array
|
||||||
|
print(f"Saving: {time.time() - t0:.1f} secs")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class DataSynchronizer:
|
||||||
|
def __init__(self, config_path="config"):
|
||||||
|
rospy.init_node("synchronizer", anonymous=True)
|
||||||
|
rospy.loginfo("ROS node initialized")
|
||||||
|
with open(config_path, "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
self.arm_axis = config["arm_axis"]
|
||||||
|
# 创建订阅者
|
||||||
|
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||||
|
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||||
|
self.camera_bottom_sub = Subscriber(
|
||||||
|
config["ros_topics"]["camera_bottom"], Image
|
||||||
|
)
|
||||||
|
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||||
|
|
||||||
|
self.left_master_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["left_master_arm"], JointState
|
||||||
|
)
|
||||||
|
self.left_slave_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["left_slave_arm"], JointState
|
||||||
|
)
|
||||||
|
self.right_master_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["right_master_arm"], JointState
|
||||||
|
)
|
||||||
|
self.right_slave_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["right_slave_arm"], JointState
|
||||||
|
)
|
||||||
|
self.left_aloha_state_pub = rospy.Subscriber(
|
||||||
|
config["ros_topics"]["left_aloha_state"],
|
||||||
|
Int32MultiArray,
|
||||||
|
self.aloha_state_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
rospy.loginfo("Subscribers created")
|
||||||
|
self.camera_names = config["camera_names"]
|
||||||
|
|
||||||
|
# 创建同步器
|
||||||
|
self.ats = ApproximateTimeSynchronizer(
|
||||||
|
[
|
||||||
|
self.camera_left_sub,
|
||||||
|
self.camera_right_sub,
|
||||||
|
self.camera_bottom_sub,
|
||||||
|
self.camera_head_sub,
|
||||||
|
self.left_master_arm_sub,
|
||||||
|
self.left_slave_arm_sub,
|
||||||
|
self.right_master_arm_sub,
|
||||||
|
self.right_slave_arm_sub,
|
||||||
|
],
|
||||||
|
queue_size=1,
|
||||||
|
slop=0.05,
|
||||||
|
)
|
||||||
|
self.ats.registerCallback(self.callback)
|
||||||
|
rospy.loginfo("Time synchronizer created and callback registered")
|
||||||
|
|
||||||
|
self.data_collector = DataCollector(
|
||||||
|
dataset_dir=config["dataset_dir"],
|
||||||
|
dataset_name=config["dataset_name"],
|
||||||
|
max_timesteps=config["max_timesteps"],
|
||||||
|
camera_names=config["camera_names"],
|
||||||
|
state_dim=config["state_dim"],
|
||||||
|
overwrite=config["overwrite"],
|
||||||
|
)
|
||||||
|
self.timesteps_collected = 0
|
||||||
|
self.begin_collect = False
|
||||||
|
self.last_time = None
|
||||||
|
|
||||||
|
def callback(
|
||||||
|
self,
|
||||||
|
camera_left_img,
|
||||||
|
camera_right_img,
|
||||||
|
camera_bottom_img,
|
||||||
|
camera_head_img,
|
||||||
|
left_master_arm,
|
||||||
|
left_slave_arm,
|
||||||
|
right_master_arm,
|
||||||
|
right_slave_arm,
|
||||||
|
):
|
||||||
|
if self.begin_collect:
|
||||||
|
self.timesteps_collected += 1
|
||||||
|
rospy.loginfo(
|
||||||
|
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.timesteps_collected = 0
|
||||||
|
return
|
||||||
|
if self.timesteps_collected == 0:
|
||||||
|
return
|
||||||
|
current_time = time.time()
|
||||||
|
if self.last_time is not None:
|
||||||
|
frequency = 1.0 / (current_time - self.last_time)
|
||||||
|
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
|
||||||
|
self.last_time = current_time
|
||||||
|
# 将ROS图像消息转换为NumPy数组
|
||||||
|
camera_left_np_img = np.frombuffer(
|
||||||
|
camera_left_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||||
|
camera_right_np_img = np.frombuffer(
|
||||||
|
camera_right_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||||
|
camera_bottom_np_img = np.frombuffer(
|
||||||
|
camera_bottom_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||||
|
camera_head_np_img = np.frombuffer(
|
||||||
|
camera_head_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||||
|
|
||||||
|
# 提取臂的角度,速度,力
|
||||||
|
left_master_arm_angle = left_master_arm.position
|
||||||
|
# left_master_arm_velocity = left_master_arm.velocity
|
||||||
|
# left_master_arm_force = left_master_arm.effort
|
||||||
|
|
||||||
|
left_slave_arm_angle = left_slave_arm.position
|
||||||
|
left_slave_arm_velocity = left_slave_arm.velocity
|
||||||
|
left_slave_arm_force = left_slave_arm.effort
|
||||||
|
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||||
|
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||||
|
|
||||||
|
right_master_arm_angle = right_master_arm.position
|
||||||
|
# right_master_arm_velocity = right_master_arm.velocity
|
||||||
|
# right_master_arm_force = right_master_arm.effort
|
||||||
|
|
||||||
|
right_slave_arm_angle = right_slave_arm.position
|
||||||
|
right_slave_arm_velocity = right_slave_arm.velocity
|
||||||
|
right_slave_arm_force = right_slave_arm.effort
|
||||||
|
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||||
|
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||||
|
|
||||||
|
# 收集数据
|
||||||
|
obs = collections.OrderedDict(
|
||||||
|
{
|
||||||
|
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||||
|
"qvel": np.concatenate(
|
||||||
|
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||||
|
),
|
||||||
|
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||||
|
"images": {
|
||||||
|
self.camera_names[0]: camera_head_np_img,
|
||||||
|
self.camera_names[1]: camera_bottom_np_img,
|
||||||
|
self.camera_names[2]: camera_left_np_img,
|
||||||
|
self.camera_names[3]: camera_right_np_img,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(self.camera_names[0])
|
||||||
|
ts = dm_env.TimeStep(
|
||||||
|
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
|
||||||
|
)
|
||||||
|
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
|
||||||
|
self.data_collector.collect_data(ts, action)
|
||||||
|
|
||||||
|
# 检查是否收集了足够的数据
|
||||||
|
if self.timesteps_collected >= self.data_collector.max_timesteps:
|
||||||
|
self.data_collector.save_data()
|
||||||
|
|
||||||
|
rospy.loginfo("Data collection complete")
|
||||||
|
|
||||||
|
self.data_collector.init_dict()
|
||||||
|
self.begin_collect = False
|
||||||
|
self.timesteps_collected = 0
|
||||||
|
|
||||||
|
def aloha_state_callback(self, data):
|
||||||
|
if not self.begin_collect:
|
||||||
|
self.aloha_state = data.data
|
||||||
|
print(self.aloha_state[0], self.aloha_state[1])
|
||||||
|
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
|
||||||
|
self.begin_collect = True
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
rospy.loginfo("Starting ROS spin")
|
||||||
|
|
||||||
|
rospy.spin()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
synchronizer = DataSynchronizer(
|
||||||
|
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
|
||||||
|
)
|
||||||
|
synchronizer.run()
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
<launch>
|
||||||
|
<!-- 左从臂节点 -->
|
||||||
|
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||||
|
<param name="aloha_state_topic" value="/left_slave_arm_aloha_state" type= "string"/>
|
||||||
|
<param name="hz" value="120" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右从臂节点 -->
|
||||||
|
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||||
|
<param name="aloha_state_topic" value="/right_slave_arm_aloha_state" type= "string"/>
|
||||||
|
<param name="hz" value="120" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 左主臂节点 -->
|
||||||
|
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
|
||||||
|
<param name="hz" value="90" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右主臂节点 -->
|
||||||
|
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_right_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
|
||||||
|
<param name="hz" value="90" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右臂相机节点 -->
|
||||||
|
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="151222072576" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 左臂相机节点 -->
|
||||||
|
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="150622070125" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 顶部相机节点 -->
|
||||||
|
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="241122071186" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 底部相机节点 -->
|
||||||
|
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="152122078546" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
</launch>
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
<launch>
|
||||||
|
<!-- 左从臂节点 -->
|
||||||
|
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||||
|
<param name="joint_actions_topic" value="/left_slave_arm_joint_actions" type= "string"/>
|
||||||
|
<param name="hz" value="90" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右从臂节点 -->
|
||||||
|
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||||
|
<param name="joint_actions_topic" value="/right_slave_arm_joint_actions" type= "string"/>
|
||||||
|
<param name="hz" value="90" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右臂相机节点 -->
|
||||||
|
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="151222072576" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 左臂相机节点 -->
|
||||||
|
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="150622070125" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 顶部相机节点 -->
|
||||||
|
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="241122071186" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 底部相机节点 -->
|
||||||
|
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="152122078546" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
</launch>
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
<launch>
|
||||||
|
<!-- 左从臂节点 -->
|
||||||
|
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_left_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右从臂节点 -->
|
||||||
|
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_right_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 左主臂节点 -->
|
||||||
|
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_left_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右主臂节点 -->
|
||||||
|
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||||
|
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_right_arm.yaml" type= "string"/>
|
||||||
|
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 右臂相机节点 -->
|
||||||
|
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="216322070299" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 左臂相机节点 -->
|
||||||
|
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="216322074992" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 顶部相机节点 -->
|
||||||
|
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="215322076086" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
|
||||||
|
<!-- 底部相机节点 -->
|
||||||
|
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||||
|
<param name="serial_number" value="215222074360" type= "string"/>
|
||||||
|
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||||
|
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||||
|
<param name="hz" value="50" type= "int"/>
|
||||||
|
</node>
|
||||||
|
</launch>
|
||||||
@@ -0,0 +1,284 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import time
|
||||||
|
import h5py
|
||||||
|
import yaml
|
||||||
|
import json
|
||||||
|
import rospy
|
||||||
|
import dm_env
|
||||||
|
import socket
|
||||||
|
import numpy as np
|
||||||
|
import collections
|
||||||
|
from datetime import datetime
|
||||||
|
from std_msgs.msg import Int32MultiArray
|
||||||
|
from sensor_msgs.msg import Image, JointState
|
||||||
|
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||||
|
|
||||||
|
|
||||||
|
class DataCollector:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dataset_dir,
|
||||||
|
dataset_name,
|
||||||
|
max_timesteps,
|
||||||
|
camera_names,
|
||||||
|
state_dim,
|
||||||
|
overwrite=False,
|
||||||
|
):
|
||||||
|
self.dataset_dir = dataset_dir
|
||||||
|
self.dataset_name = dataset_name
|
||||||
|
self.max_timesteps = max_timesteps
|
||||||
|
self.camera_names = camera_names
|
||||||
|
self.state_dim = state_dim
|
||||||
|
self.overwrite = overwrite
|
||||||
|
self.init_dict()
|
||||||
|
self.create_dataset_dir()
|
||||||
|
|
||||||
|
def init_dict(self):
|
||||||
|
self.data_dict = {
|
||||||
|
"/observations/qpos": [],
|
||||||
|
"/observations/qvel": [],
|
||||||
|
"/observations/effort": [],
|
||||||
|
"/action": [],
|
||||||
|
}
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
self.data_dict[f"/observations/images/{cam_name}"] = []
|
||||||
|
|
||||||
|
def create_dataset_dir(self):
|
||||||
|
# 按照年月日创建目录
|
||||||
|
date_str = datetime.now().strftime("%Y%m%d")
|
||||||
|
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||||
|
if not os.path.exists(self.dataset_dir):
|
||||||
|
os.makedirs(self.dataset_dir)
|
||||||
|
|
||||||
|
def create_file(self):
|
||||||
|
# 检查数据集名称是否存在,如果存在则递增名称
|
||||||
|
counter = 0
|
||||||
|
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
|
||||||
|
if not self.overwrite:
|
||||||
|
while os.path.exists(dataset_path + ".hdf5"):
|
||||||
|
dataset_path = os.path.join(
|
||||||
|
self.dataset_dir, f"{self.dataset_name}_{counter}"
|
||||||
|
)
|
||||||
|
counter += 1
|
||||||
|
self.dataset_path = dataset_path
|
||||||
|
|
||||||
|
def collect_data(self, ts, action):
|
||||||
|
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
|
||||||
|
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
|
||||||
|
self.data_dict["/observations/effort"].append(ts.observation["effort"])
|
||||||
|
self.data_dict["/action"].append(action)
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
self.data_dict[f"/observations/images/{cam_name}"].append(
|
||||||
|
ts.observation["images"][cam_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
def save_data(self):
|
||||||
|
self.create_file()
|
||||||
|
t0 = time.time()
|
||||||
|
# 保存数据
|
||||||
|
with h5py.File(
|
||||||
|
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
|
||||||
|
) as root:
|
||||||
|
root.attrs["sim"] = False
|
||||||
|
obs = root.create_group("observations")
|
||||||
|
image = obs.create_group("images")
|
||||||
|
for cam_name in self.camera_names:
|
||||||
|
_ = image.create_dataset(
|
||||||
|
cam_name,
|
||||||
|
(self.max_timesteps, 480, 640, 3),
|
||||||
|
dtype="uint8",
|
||||||
|
chunks=(1, 480, 640, 3),
|
||||||
|
)
|
||||||
|
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
|
||||||
|
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
|
||||||
|
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
|
||||||
|
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
|
||||||
|
|
||||||
|
for name, array in self.data_dict.items():
|
||||||
|
root[name][...] = array
|
||||||
|
print(f"Saving: {time.time() - t0:.1f} secs")
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class DataSynchronizer:
|
||||||
|
def __init__(self, config_path="config"):
|
||||||
|
rospy.init_node("synchronizer", anonymous=True)
|
||||||
|
rospy.loginfo("ROS node initialized")
|
||||||
|
with open(config_path, "r") as file:
|
||||||
|
config = yaml.safe_load(file)
|
||||||
|
self.arm_axis = config["arm_axis"]
|
||||||
|
# 创建订阅者
|
||||||
|
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||||
|
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||||
|
self.camera_bottom_sub = Subscriber(
|
||||||
|
config["ros_topics"]["camera_bottom"], Image
|
||||||
|
)
|
||||||
|
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||||
|
|
||||||
|
self.left_master_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["left_master_arm"], JointState
|
||||||
|
)
|
||||||
|
self.left_slave_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["left_slave_arm"], JointState
|
||||||
|
)
|
||||||
|
self.right_master_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["right_master_arm"], JointState
|
||||||
|
)
|
||||||
|
self.right_slave_arm_sub = Subscriber(
|
||||||
|
config["ros_topics"]["right_slave_arm"], JointState
|
||||||
|
)
|
||||||
|
self.left_aloha_state_pub = rospy.Subscriber(
|
||||||
|
config["ros_topics"]["left_aloha_state"],
|
||||||
|
Int32MultiArray,
|
||||||
|
self.aloha_state_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
rospy.loginfo("Subscribers created")
|
||||||
|
|
||||||
|
# 创建同步器
|
||||||
|
self.ats = ApproximateTimeSynchronizer(
|
||||||
|
[
|
||||||
|
self.camera_left_sub,
|
||||||
|
self.camera_right_sub,
|
||||||
|
self.camera_bottom_sub,
|
||||||
|
self.camera_head_sub,
|
||||||
|
self.left_master_arm_sub,
|
||||||
|
self.left_slave_arm_sub,
|
||||||
|
self.right_master_arm_sub,
|
||||||
|
self.right_slave_arm_sub,
|
||||||
|
],
|
||||||
|
queue_size=1,
|
||||||
|
slop=0.05,
|
||||||
|
)
|
||||||
|
self.ats.registerCallback(self.callback)
|
||||||
|
rospy.loginfo("Time synchronizer created and callback registered")
|
||||||
|
|
||||||
|
self.data_collector = DataCollector(
|
||||||
|
dataset_dir=config["dataset_dir"],
|
||||||
|
dataset_name=config["dataset_name"],
|
||||||
|
max_timesteps=config["max_timesteps"],
|
||||||
|
camera_names=config["camera_names"],
|
||||||
|
state_dim=config["state_dim"],
|
||||||
|
overwrite=config["overwrite"],
|
||||||
|
)
|
||||||
|
self.timesteps_collected = 0
|
||||||
|
self.begin_collect = False
|
||||||
|
self.last_time = None
|
||||||
|
|
||||||
|
def callback(
|
||||||
|
self,
|
||||||
|
camera_left_img,
|
||||||
|
camera_right_img,
|
||||||
|
camera_bottom_img,
|
||||||
|
camera_head_img,
|
||||||
|
left_master_arm,
|
||||||
|
left_slave_arm,
|
||||||
|
right_master_arm,
|
||||||
|
right_slave_arm,
|
||||||
|
):
|
||||||
|
if self.begin_collect:
|
||||||
|
self.timesteps_collected += 1
|
||||||
|
rospy.loginfo(
|
||||||
|
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.timesteps_collected = 0
|
||||||
|
return
|
||||||
|
if self.timesteps_collected == 0:
|
||||||
|
return
|
||||||
|
current_time = time.time()
|
||||||
|
if self.last_time is not None:
|
||||||
|
frequency = 1.0 / (current_time - self.last_time)
|
||||||
|
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
|
||||||
|
self.last_time = current_time
|
||||||
|
# 将ROS图像消息转换为NumPy数组
|
||||||
|
camera_left_np_img = np.frombuffer(
|
||||||
|
camera_left_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||||
|
camera_right_np_img = np.frombuffer(
|
||||||
|
camera_right_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||||
|
camera_bottom_np_img = np.frombuffer(
|
||||||
|
camera_bottom_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||||
|
camera_head_np_img = np.frombuffer(
|
||||||
|
camera_head_img.data, dtype=np.uint8
|
||||||
|
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||||
|
|
||||||
|
# 提取臂的角度,速度,力
|
||||||
|
left_master_arm_angle = left_master_arm.position
|
||||||
|
# left_master_arm_velocity = left_master_arm.velocity
|
||||||
|
# left_master_arm_force = left_master_arm.effort
|
||||||
|
|
||||||
|
left_slave_arm_angle = left_slave_arm.position
|
||||||
|
left_slave_arm_velocity = left_slave_arm.velocity
|
||||||
|
left_slave_arm_force = left_slave_arm.effort
|
||||||
|
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||||
|
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||||
|
|
||||||
|
right_master_arm_angle = right_master_arm.position
|
||||||
|
# right_master_arm_velocity = right_master_arm.velocity
|
||||||
|
# right_master_arm_force = right_master_arm.effort
|
||||||
|
|
||||||
|
right_slave_arm_angle = right_slave_arm.position
|
||||||
|
right_slave_arm_velocity = right_slave_arm.velocity
|
||||||
|
right_slave_arm_force = right_slave_arm.effort
|
||||||
|
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||||
|
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||||
|
|
||||||
|
# 收集数据
|
||||||
|
obs = collections.OrderedDict(
|
||||||
|
{
|
||||||
|
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||||
|
"qvel": np.concatenate(
|
||||||
|
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||||
|
),
|
||||||
|
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||||
|
"images": {
|
||||||
|
"cam_front": camera_head_np_img,
|
||||||
|
"cam_low": camera_bottom_np_img,
|
||||||
|
"cam_left": camera_left_np_img,
|
||||||
|
"cam_right": camera_right_np_img,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ts = dm_env.TimeStep(
|
||||||
|
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
|
||||||
|
)
|
||||||
|
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
|
||||||
|
self.data_collector.collect_data(ts, action)
|
||||||
|
|
||||||
|
# 检查是否收集了足够的数据
|
||||||
|
if self.timesteps_collected >= self.data_collector.max_timesteps:
|
||||||
|
self.data_collector.save_data()
|
||||||
|
|
||||||
|
rospy.loginfo("Data collection complete")
|
||||||
|
|
||||||
|
self.data_collector.init_dict()
|
||||||
|
self.begin_collect = False
|
||||||
|
self.timesteps_collected = 0
|
||||||
|
|
||||||
|
def aloha_state_callback(self, data):
|
||||||
|
if not self.begin_collect:
|
||||||
|
self.aloha_state = data.data
|
||||||
|
print(self.aloha_state[0], self.aloha_state[1])
|
||||||
|
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
|
||||||
|
self.begin_collect = True
|
||||||
|
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
rospy.loginfo("Starting ROS spin")
|
||||||
|
|
||||||
|
rospy.spin()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
synchronizer = DataSynchronizer(
|
||||||
|
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
|
||||||
|
)
|
||||||
|
synchronizer.run()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
<?xml version="1.0"?>
|
||||||
|
<package format="2">
|
||||||
|
<name>shadow_rm_aloha</name>
|
||||||
|
<version>0.0.1</version>
|
||||||
|
<description>The shadow_rm_aloha package</description>
|
||||||
|
|
||||||
|
<maintainer email="your_email@example.com">Your Name</maintainer>
|
||||||
|
|
||||||
|
<license>TODO</license>
|
||||||
|
|
||||||
|
<buildtool_depend>catkin</buildtool_depend>
|
||||||
|
|
||||||
|
<build_depend>rospy</build_depend>
|
||||||
|
<build_depend>sensor_msgs</build_depend>
|
||||||
|
<build_depend>std_msgs</build_depend>
|
||||||
|
<build_depend>cv_bridge</build_depend>
|
||||||
|
<build_depend>image_transport</build_depend>
|
||||||
|
<build_depend>message_generation</build_depend>
|
||||||
|
<build_depend>message_runtime</build_depend>
|
||||||
|
|
||||||
|
<exec_depend>rospy</exec_depend>
|
||||||
|
<exec_depend>sensor_msgs</exec_depend>
|
||||||
|
<exec_depend>std_msgs</exec_depend>
|
||||||
|
<exec_depend>cv_bridge</exec_depend>
|
||||||
|
<exec_depend>image_transport</exec_depend>
|
||||||
|
<exec_depend>message_runtime</exec_depend>
|
||||||
|
|
||||||
|
|
||||||
|
<export>
|
||||||
|
</export>
|
||||||
|
</package>
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
# GetArmStatus.srv
|
||||||
|
|
||||||
|
---
|
||||||
|
sensor_msgs/JointState joint_status
|
||||||
|
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
# GetImage.srv
|
||||||
|
---
|
||||||
|
bool success
|
||||||
|
sensor_msgs/Image image
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
# MoveArm.srv
|
||||||
|
float32[] joint_angle
|
||||||
|
---
|
||||||
|
bool success
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
__version__ = '0.1.0'
|
||||||
49
realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py
Normal file
49
realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import multiprocessing as mp
|
||||||
|
import time
|
||||||
|
|
||||||
|
def collect_data(arm_id, cam_id, data_queue, lock):
|
||||||
|
while True:
|
||||||
|
# 模拟数据采集
|
||||||
|
arm_data = f"Arm {arm_id} data"
|
||||||
|
cam_data = f"Cam {cam_id} data"
|
||||||
|
|
||||||
|
# 获取当前时间戳
|
||||||
|
timestamp = time.time()
|
||||||
|
|
||||||
|
# 将数据放入队列
|
||||||
|
with lock:
|
||||||
|
data_queue.put((timestamp, arm_data, cam_data))
|
||||||
|
|
||||||
|
# 模拟高帧率
|
||||||
|
time.sleep(0.01)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
num_arms = 4
|
||||||
|
num_cams = 4
|
||||||
|
|
||||||
|
# 创建队列和锁
|
||||||
|
data_queue = mp.Queue()
|
||||||
|
lock = mp.Lock()
|
||||||
|
|
||||||
|
# 创建进程
|
||||||
|
processes = []
|
||||||
|
for i in range(num_arms):
|
||||||
|
p = mp.Process(target=collect_data, args=(i, i, data_queue, lock))
|
||||||
|
processes.append(p)
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
# 主进程处理数据
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if not data_queue.empty():
|
||||||
|
with lock:
|
||||||
|
timestamp, arm_data, cam_data = data_queue.get()
|
||||||
|
print(f"Timestamp: {timestamp}, {arm_data}, {cam_data}")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
for p in processes:
|
||||||
|
p.terminate()
|
||||||
|
for p in processes:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
from shadow_rm_aloha.data_sub_process.aloha_data_synchronizer import DataCollector
|
||||||
|
|
||||||
|
def test_create_dataset_dir():
|
||||||
|
# 设置测试参数
|
||||||
|
dataset_dir = './test_data/dataset'
|
||||||
|
dataset_name = 'test_episode'
|
||||||
|
max_timesteps = 100
|
||||||
|
camera_names = ['cam1', 'cam2']
|
||||||
|
overwrite = False
|
||||||
|
|
||||||
|
# 清理旧的测试数据
|
||||||
|
if os.path.exists(dataset_dir):
|
||||||
|
shutil.rmtree(dataset_dir)
|
||||||
|
|
||||||
|
# 创建 DataCollector 实例并调用 create_dataset_dir
|
||||||
|
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, overwrite)
|
||||||
|
|
||||||
|
# 检查目录是否按预期创建
|
||||||
|
date_str = datetime.now().strftime("%Y%m%d")
|
||||||
|
expected_dir = os.path.join(dataset_dir, date_str)
|
||||||
|
assert os.path.exists(expected_dir), f"Expected directory {expected_dir} does not exist."
|
||||||
|
|
||||||
|
# 检查文件名是否按预期递增
|
||||||
|
expected_file = os.path.join(expected_dir, dataset_name + '.hdf5')
|
||||||
|
assert collector.dataset_path == expected_file, f"Expected file path {expected_file}, but got {collector.dataset_path}"
|
||||||
|
|
||||||
|
# 再次调用 create_dataset_dir,检查文件名是否递增
|
||||||
|
# collector.create_dataset_dir()
|
||||||
|
expected_file_incremented = os.path.join(expected_dir, dataset_name + '_1.hdf5')
|
||||||
|
assert collector.dataset_path == expected_file_incremented, f"Expected file path {expected_file_incremented}, but got {collector.dataset_path}"
|
||||||
|
|
||||||
|
print("All tests passed.")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_create_dataset_dir()
|
||||||
105
realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py
Normal file
105
realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
import socket
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 设置日志级别
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class test_udp():
|
||||||
|
def __init__(self):
|
||||||
|
arm_ip = '192.168.1.19'
|
||||||
|
arm_port = 8080
|
||||||
|
self.arm =socket.socket()
|
||||||
|
self.arm.connect((arm_ip, arm_port))
|
||||||
|
set_udp = {"command":"set_realtime_push","cycle":1,"enable":True,"port":8090,"ip":"192.168.1.101","custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}}
|
||||||
|
self.arm.send(json.dumps(set_udp).encode('utf-8'))
|
||||||
|
state = self.arm.recv(1024)
|
||||||
|
|
||||||
|
logging.info(f"Send data to {arm_ip}:{arm_port}: {state}")
|
||||||
|
|
||||||
|
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
|
||||||
|
# 设置套接字选项,允许端口复用
|
||||||
|
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
|
||||||
|
local_ip = "192.168.1.101"
|
||||||
|
local_port = 8090
|
||||||
|
self.udp_socket.bind((local_ip, local_port))
|
||||||
|
self.BUFFER_SIZE = 1024
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def set_udp(self):
|
||||||
|
|
||||||
|
while True:
|
||||||
|
start_time = time.time()
|
||||||
|
data, addr = self.udp_socket.recvfrom(self.BUFFER_SIZE)
|
||||||
|
# 将接收到的UDP数据解码并解析为JSON
|
||||||
|
data = json.loads(data.decode('utf-8'))
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Received data {data}")
|
||||||
|
|
||||||
|
udp_socket.close()
|
||||||
|
|
||||||
|
|
||||||
|
def collect_arm_data(arm_id, queue, event):
|
||||||
|
while True:
|
||||||
|
data = f"Arm {arm_id} data {random.random()}"
|
||||||
|
queue.put((arm_id, data))
|
||||||
|
event.set()
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def collect_camera_data(camera_id, queue, event):
|
||||||
|
while True:
|
||||||
|
data = f"Camera {camera_id} data {random.random()}"
|
||||||
|
queue.put((camera_id, data))
|
||||||
|
event.set()
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
arm_queues = [multiprocessing.Queue() for _ in range(4)]
|
||||||
|
camera_queues = [multiprocessing.Queue() for _ in range(4)]
|
||||||
|
arm_events = [multiprocessing.Event() for _ in range(4)]
|
||||||
|
camera_events = [multiprocessing.Event() for _ in range(4)]
|
||||||
|
|
||||||
|
arm_processes = [multiprocessing.Process(target=collect_arm_data, args=(i, arm_queues[i], arm_events[i])) for i in range(4)]
|
||||||
|
camera_processes = [multiprocessing.Process(target=collect_camera_data, args=(i, camera_queues[i], camera_events[i])) for i in range(4)]
|
||||||
|
|
||||||
|
for p in arm_processes + camera_processes:
|
||||||
|
p.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
for event in arm_events + camera_events:
|
||||||
|
event.wait()
|
||||||
|
|
||||||
|
for i in range(4):
|
||||||
|
if not arm_queues[i].empty():
|
||||||
|
arm_id, arm_data = arm_queues[i].get()
|
||||||
|
print(f"Received from Arm {arm_id}: {arm_data}")
|
||||||
|
arm_events[i].clear()
|
||||||
|
|
||||||
|
if not camera_queues[i].empty():
|
||||||
|
camera_id, camera_data = camera_queues[i].get()
|
||||||
|
print(f"Received from Camera {camera_id}: {camera_data}")
|
||||||
|
camera_events[i].clear()
|
||||||
|
|
||||||
|
time.sleep(0.1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
for p in arm_processes + camera_processes:
|
||||||
|
p.terminate()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == "__main__":
|
||||||
|
# test_udp = test_udp()
|
||||||
|
# test_udp.set_udp()
|
||||||
4
realman_src/realman_aloha/shadow_rm_robot/.gitignore
vendored
Normal file
4
realman_src/realman_aloha/shadow_rm_robot/.gitignore
vendored
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pt
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user