Compare commits

...

14 Commits
main ... depth

Author SHA1 Message Date
c28e774234 optimaze the speed of end pose control 2025-06-13 20:17:31 +08:00
80b1a97e4c change opencv to realsense camera 2025-06-12 17:56:21 +08:00
f4fec8f51c change pose control api 2025-06-11 16:17:39 +08:00
f4f82c916f some bug still 2025-06-11 15:20:14 +08:00
ecbe154709 no change 2025-06-09 16:24:00 +08:00
d00c154db9 update state 2025-06-09 16:23:09 +08:00
55f284b306 mix control fix bug 2025-06-09 10:58:28 +08:00
cf8df17d3a add realman shadow src 2025-06-07 11:29:43 +08:00
e079566597 xbox controller demo 2025-06-07 11:22:05 +08:00
83d6419d70 first commit: gamepad control 2025-06-05 21:56:52 +08:00
a0ec9e1cb1 single arm test 2025-06-05 15:50:26 +08:00
3eede4447d dual arm test 2025-06-05 15:50:18 +08:00
9c6a7d9701 new md 2025-06-05 15:50:11 +08:00
7b201773f3 single arm test 2025-06-05 15:49:57 +08:00
117 changed files with 16384 additions and 2 deletions

View File

@@ -58,7 +58,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f
log_dt("dt", dt_s)
# TODO(aliberts): move robot-specific logs logic in robot.print_logs()
if not robot.robot_type.startswith("stretch"):
if not robot.robot_type.startswith(("stretch", "realman")):
for name in robot.leader_arms:
key = f"read_leader_{name}_pos_dt_s"
if key in robot.logs:

View File

@@ -39,3 +39,12 @@ class FeetechMotorsBusConfig(MotorsBusConfig):
port: str
motors: dict[str, tuple[int, str]]
mock: bool = False
@MotorsBusConfig.register_subclass("realman")
@dataclass
class RealmanMotorsBusConfig(MotorsBusConfig):
ip: str
port: int
motors: dict[str, tuple[int, str]]
init_joint: dict[str, list]
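For reference, a minimal sketch of instantiating this config (the values mirror the left follower arm in the `RealmanRobotConfig` further down this diff; the IP and port are assumptions for illustration):

```python
from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig

# Assumed values, mirroring RealmanRobotConfig.left_follower_arm below.
bus_config = RealmanMotorsBusConfig(
    ip="192.168.3.18",
    port=8080,
    motors={
        # name: (index, model)
        "joint_1": (1, "realman"),
        "joint_2": (2, "realman"),
        "joint_3": (3, "realman"),
        "joint_4": (4, "realman"),
        "joint_5": (5, "realman"),
        "joint_6": (6, "realman"),
        "gripper": (7, "realman"),
    },
    init_joint={"joint": [-90, 90, 90, -90, -90, 90, 1000]},  # 6 joint angles (deg) + raw gripper position
)
```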

View File

@@ -0,0 +1,128 @@
import time
from typing import Dict
from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig
from Robotic_Arm.rm_robot_interface import *
class RealmanMotorsBus:
"""
A thin wrapper around the Realman SDK.
"""
def __init__(self,
config: RealmanMotorsBusConfig):
self.rmarm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
self.handle = self.rmarm.rm_create_robot_arm(config.ip, config.port)
self.motors = config.motors
self.init_joint_position = config.init_joint['joint'] # [6 joints + 1 gripper]
self.safe_disable_position = config.init_joint['joint']
@property
def motor_names(self) -> list[str]:
return list(self.motors.keys())
@property
def motor_models(self) -> list[str]:
return [model for _, model in self.motors.values()]
@property
def motor_indices(self) -> list[int]:
return [idx for idx, _ in self.motors.values()]
def connect(self, enable=True) -> bool:
'''
Enable the arm and check its enable state, retrying for up to 5 s; give up if enabling times out.
'''
enable_flag = False
loop_flag = False
# timeout in seconds
timeout = 5
# record the time before entering the loop
start_time = time.time()
elapsed_time_flag = False
while not loop_flag:
elapsed_time = time.time() - start_time
print("--------------------")
if enable:
# query the arm state
ret = self.rmarm.rm_get_current_arm_state()
if ret[0] == 0: # state read successfully
enable_flag = True
else:
enable_flag = False
# disconnect everything and destroy the SDK threads
RoboticArm.rm_destory()
print("Enable state:", enable_flag)
print("--------------------")
if(enable_flag == enable):
loop_flag = True
enable_flag = True
else:
loop_flag = False
enable_flag = False
# check whether the timeout has been exceeded
if elapsed_time > timeout:
print("Timed out....")
elapsed_time_flag = True
enable_flag = True
break
time.sleep(1)
resp = enable_flag
print(f"Returning response: {resp}")
return resp
def set_calibration(self):
return
def revert_calibration(self):
return
def apply_calibration(self):
"""
Move to the initial joint position.
"""
self.write(target_joint=self.init_joint_position)
def write(self, target_joint:list):
# self.rmarm.rm_movej(target_joint[:-1], 50, 0, 0, 1)
self.rmarm.rm_movej_follow(target_joint[:-1])
self.rmarm.rm_set_gripper_position(target_joint[-1], block=False, timeout=2)
def write_endpose(self, target_endpose: list, gripper: int):
self.rmarm.rm_movej_p(target_endpose, 50, 0, 0, 1)
self.rmarm.rm_set_gripper_position(gripper, block=False, timeout=2)
def read(self) -> Dict:
"""
- Arm joint positions in degrees, normalized to [-1, 1]
- Gripper position, normalized to [-1, 1]
"""
joint_msg = self.rmarm.rm_get_current_arm_state()[1]
joint_state = joint_msg['joint']
gripper_msg = self.rmarm.rm_get_gripper_state()[1]
gripper_state = gripper_msg['actpos']
return {
"joint_1": joint_state[0]/180,
"joint_2": joint_state[1]/180,
"joint_3": joint_state[2]/180,
"joint_4": joint_state[3]/180,
"joint_5": joint_state[4]/180,
"joint_6": joint_state[5]/180,
"gripper": (gripper_state-500)/500
}
def safe_disconnect(self):
"""
Move to safe disconnect position
"""
self.write(target_joint=self.safe_disable_position)
# disconnect everything and destroy the SDK threads
RoboticArm.rm_destory()
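A hedged usage sketch of the bus defined above (it assumes an arm reachable at the address used in `RealmanRobotConfig` below; the exact motion depends on your hardware):

```python
from lerobot.common.robot_devices.motors.configs import RealmanMotorsBusConfig
from lerobot.common.robot_devices.motors.realman import RealmanMotorsBus

cfg = RealmanMotorsBusConfig(
    ip="192.168.3.18",  # assumed arm address
    port=8080,
    motors={f"joint_{i}": (i, "realman") for i in range(1, 7)} | {"gripper": (7, "realman")},
    init_joint={"joint": [-90, 90, 90, -90, -90, 90, 1000]},
)

bus = RealmanMotorsBus(cfg)
bus.connect(enable=True)   # enable the arm, polling its state for up to 5 s
bus.apply_calibration()    # move to the configured initial joint position
print(bus.read())          # joints and gripper normalized to [-1, 1]
bus.safe_disconnect()      # return to the safe position and destroy the SDK threads
```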

View File

@@ -44,6 +44,11 @@ def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig
motors_buses[key] = FeetechMotorsBus(cfg)
elif cfg.type == "realman":
from lerobot.common.robot_devices.motors.realman import RealmanMotorsBus
motors_buses[key] = RealmanMotorsBus(cfg)
else:
raise ValueError(f"The motor type '{cfg.type}' is not valid.")
@@ -65,3 +70,7 @@ def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus:
else:
raise ValueError(f"The motor type '{motor_type}' is not valid.")
def get_motor_names(arm: dict[str, MotorsBus]) -> list:
return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors]

View File

@@ -27,6 +27,7 @@ from lerobot.common.robot_devices.motors.configs import (
DynamixelMotorsBusConfig,
FeetechMotorsBusConfig,
MotorsBusConfig,
RealmanMotorsBusConfig
)
@@ -674,3 +675,91 @@ class LeKiwiRobotConfig(RobotConfig):
)
mock: bool = False
@RobotConfig.register_subclass("realman")
@dataclass
class RealmanRobotConfig(RobotConfig):
inference_time: bool = False
max_gripper: int = 990
min_gripper: int = 10
servo_config_file: str = "/home/maic/LYT/lerobot/realman_src/realman_aloha/shadow_rm_robot/config/servo_arm.yaml"
left_follower_arm: dict[str, MotorsBusConfig] = field(
default_factory=lambda: {
"main": RealmanMotorsBusConfig(
ip = "192.168.3.18",
port = 8080,
motors={
# name: (index, model)
"joint_1": [1, "realman"],
"joint_2": [2, "realman"],
"joint_3": [3, "realman"],
"joint_4": [4, "realman"],
"joint_5": [5, "realman"],
"joint_6": [6, "realman"],
"gripper": [7, "realman"],
},
init_joint = {'joint': [-90, 90, 90, -90, -90, 90, 1000]}
)
}
)
cameras: dict[str, CameraConfig] = field(
default_factory=lambda: {
# "one": OpenCVCameraConfig(
# camera_index=4,
# fps=30,
# width=640,
# height=480,
# ),
"left": IntelRealSenseCameraConfig(
serial_number="153122077516",
fps=30,
width=640,
height=480,
use_depth=False
),
"right": IntelRealSenseCameraConfig(
serial_number="405622075165",
fps=30,
width=640,
height=480,
use_depth=False
),
"front": IntelRealSenseCameraConfig(
serial_number="145422072751",
fps=30,
width=640,
height=480,
use_depth=False
),
"high": IntelRealSenseCameraConfig(
serial_number="145422072193",
fps=30,
width=640,
height=480,
use_depth=False
),
}
)
# right_follower_arm: dict[str, MotorsBusConfig] = field(
# default_factory=lambda: {
# "main": RealmanMotorsBusConfig(
# ip = "192.168.3.19",
# port = 8080,
# motors={
# # name: (index, model)
# "joint_1": [1, "realman"],
# "joint_2": [2, "realman"],
# "joint_3": [3, "realman"],
# "joint_4": [4, "realman"],
# "joint_5": [5, "realman"],
# "joint_6": [6, "realman"],
# "gripper": (7, "realman"),
# },
# )
# }
# )
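A hedged sketch of constructing this config with a few fields overridden (the override values here are only examples; extra keyword arguments can also be merged later via `dataclasses.replace` in `RealmanRobot.__init__`, shown further down this diff):

```python
from lerobot.common.robot_devices.robots.configs import RealmanRobotConfig

# Example overrides; adjust for your setup.
cfg = RealmanRobotConfig(inference_time=True, max_gripper=900, min_gripper=20)
print(cfg.type)                           # "realman", as registered via RobotConfig.register_subclass
print(cfg.left_follower_arm["main"].ip)   # "192.168.3.18"
```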

View File

@@ -0,0 +1,305 @@
"""
Teleoperation of a Realman arm with a PS5 controller and an optional serial master arm.
"""
import time
import torch
import numpy as np
from dataclasses import dataclass, field, replace
from collections import deque
from lerobot.common.robot_devices.teleop.gamepad import HybridController
from lerobot.common.robot_devices.motors.utils import get_motor_names, make_motors_buses_from_configs
from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs
from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
from lerobot.common.robot_devices.robots.configs import RealmanRobotConfig
class RealmanRobot:
def __init__(self, config: RealmanRobotConfig | None = None, **kwargs):
if config is None:
config = RealmanRobotConfig()
# Overwrite config arguments using kwargs
self.config = replace(config, **kwargs)
self.robot_type = self.config.type
self.inference_time = self.config.inference_time # if it is inference time
# build cameras
self.cameras = make_cameras_from_configs(self.config.cameras)
# build realman motors
self.piper_motors = make_motors_buses_from_configs(self.config.left_follower_arm)
self.arm = self.piper_motors['main']
self.arm.rmarm.rm_movej(self.arm.init_joint_position[:-1], 50, 0, 0, 1)
time.sleep(2)
ret = self.arm.rmarm.rm_get_current_arm_state()
init_pose = ret[1]['pose']
# build init teleop info
self.init_info = {
'init_joint': self.arm.init_joint_position,
'init_pose': init_pose,
'max_gripper': config.max_gripper,
'min_gripper': config.min_gripper,
'servo_config_file': config.servo_config_file
}
# build state-action cache
self.joint_queue = deque(maxlen=2)
# build gamepad teleop
if not self.inference_time:
self.teleop = HybridController(self.init_info)
else:
self.teleop = None
self.logs = {}
self.is_connected = False
@property
def camera_features(self) -> dict:
cam_ft = {}
for cam_key, cam in self.cameras.items():
key = f"observation.images.{cam_key}"
cam_ft[key] = {
"shape": (cam.height, cam.width, cam.channels),
"names": ["height", "width", "channels"],
"info": None,
}
return cam_ft
@property
def motor_features(self) -> dict:
action_names = get_motor_names(self.piper_motors)
state_names = get_motor_names(self.piper_motors)
return {
"action": {
"dtype": "float32",
"shape": (len(action_names),),
"names": action_names,
},
"observation.state": {
"dtype": "float32",
"shape": (len(state_names),),
"names": state_names,
},
}
@property
def has_camera(self):
return len(self.cameras) > 0
@property
def num_cameras(self):
return len(self.cameras)
def connect(self) -> None:
"""Connect RealmanArm and cameras"""
if self.is_connected:
raise RobotDeviceAlreadyConnectedError(
"RealmanArm is already connected. Do not run `robot.connect()` twice."
)
# connect RealmanArm
self.arm.connect(enable=True)
print("RealmanArm conneted")
# connect cameras
for name in self.cameras:
self.cameras[name].connect()
self.is_connected = self.is_connected and self.cameras[name].is_connected
print(f"camera {name} conneted")
print("All connected")
self.is_connected = True
self.run_calibration()
def disconnect(self) -> None:
"""move to home position, disenable piper and cameras"""
# move piper to home position, disable
if not self.inference_time:
self.teleop.stop()
# disconnect the arm
self.arm.safe_disconnect()
print("RealmanArm disable after 5 seconds")
time.sleep(5)
self.arm.connect(enable=False)
# disconnect cameras
if len(self.cameras) > 0:
for cam in self.cameras.values():
cam.disconnect()
self.is_connected = False
def run_calibration(self):
"""move piper to the home position"""
if not self.is_connected:
raise ConnectionError()
self.arm.apply_calibration()
if not self.inference_time:
self.teleop.reset()
def teleop_step(
self, record_data=False
) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
if not self.is_connected:
raise ConnectionError()
if self.teleop is None and self.inference_time:
self.teleop = HybridController(self.init_info)
# read the current state and the teleop action
before_read_t = time.perf_counter()
state = self.arm.read() # read current joint position from robot
action = self.teleop.get_action() # target joint position and pose end pos from gamepad
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
if action['control_mode'] == 'joint':
# joint-control mode (primary mode)
ret = self.arm.rmarm.rm_get_current_arm_state()
current_pose = ret[1]['pose']
self.teleop.update_endpose_state(current_pose)
target_joints = action['joint_angles']
# do action
before_write_t = time.perf_counter()
self.joint_queue.append(list(self.arm.read().values()))
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
else:
# end-effector pose control mode
target_pose = [
action['end_pose']['X'], # X (m)
action['end_pose']['Y'], # Y (m)
action['end_pose']['Z'], # Z (m)
action['end_pose']['RX'], # RX (rad)
action['end_pose']['RY'], # RY (rad)
action['end_pose']['RZ'] # RZ (rad)
]
# do action
before_write_t = time.perf_counter()
# result = self.arm.rmarm.rm_movej_p(target_pose, 100, 0, 0, 0)
self.arm.rmarm.rm_movep_follow(target_pose)
self.arm.rmarm.rm_set_gripper_position(action['gripper'], False, 2)
ret = self.arm.rmarm.rm_get_current_arm_state()
target_joints = ret[1].get('joint', self.arm.init_joint_position)
self.joint_queue.append(list(self.arm.read().values()))
self.teleop.update_joint_state(target_joints)
self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t
# print('-'*80)
# print('mode: ', action['control_mode'])
# print('state: ', list(state.values()))
# print('action: ', target_joints)
# print('cache[0]: ', self.joint_queue[0])
# print('cache[-1]: ', self.joint_queue[-1])
# print('time: ', time.perf_counter() - before_write_t)
# print('-'*80)
# time.sleep(1)
if not record_data:
return
state = torch.as_tensor(list(self.joint_queue[0]), dtype=torch.float32)
action = torch.as_tensor(list(self.joint_queue[-1]), dtype=torch.float32)
# Capture images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries
obs_dict, action_dict = {}, {}
obs_dict["observation.state"] = state
action_dict["action"] = action
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict, action_dict
def send_action(self, action: torch.Tensor) -> torch.Tensor:
"""Write the predicted actions from policy to the motors"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"Piper is not connected. You need to run `robot.connect()`."
)
# send to motors, torch to list
target_joints = action.tolist()
len_joint = len(target_joints) - 1
target_joints = [target_joints[i]*180 for i in range(len_joint)] + [target_joints[-1]]  # de-normalize joints: [-1, 1] -> degrees
target_joints[-1] = int(target_joints[-1]*500 + 500)  # de-normalize gripper: [-1, 1] -> [0, 1000]
self.arm.write(target_joints)
return action
def capture_observation(self) -> dict:
"""capture current images and joint positions"""
if not self.is_connected:
raise RobotDeviceNotConnectedError(
"Piper is not connected. You need to run `robot.connect()`."
)
# read current joint positions
before_read_t = time.perf_counter()
state = self.arm.read() # 6 joints + 1 gripper
self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t
state = torch.as_tensor(list(state.values()), dtype=torch.float32)
# read images from cameras
images = {}
for name in self.cameras:
before_camread_t = time.perf_counter()
images[name] = self.cameras[name].async_read()
images[name] = torch.from_numpy(images[name])
self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"]
self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t
# Populate output dictionaries and format to pytorch
obs_dict = {}
obs_dict["observation.state"] = state
for name in self.cameras:
obs_dict[f"observation.images.{name}"] = images[name]
return obs_dict
def teleop_safety_stop(self):
""" move to home position after record one episode """
self.run_calibration()
def __del__(self):
if self.is_connected:
self.disconnect()
if not self.inference_time:
self.teleop.stop()
if __name__ == '__main__':
robot = RealmanRobot()
robot.connect()
# robot.run_calibration()
while True:
robot.teleop_step(record_data=True)
robot.capture_observation()
dummy_action = torch.Tensor([-0.40586111280653214, 0.5522833506266276, 0.4998166826036241, -0.3539944542778863, -0.524433347913954, 0.9064999898274739, 0.482])
robot.send_action(dummy_action)
robot.disconnect()
print('ok')

View File

@@ -25,6 +25,7 @@ from lerobot.common.robot_devices.robots.configs import (
So100RobotConfig,
So101RobotConfig,
StretchRobotConfig,
RealmanRobotConfig
)
@@ -65,6 +66,9 @@ def make_robot_config(robot_type: str, **kwargs) -> RobotConfig:
return StretchRobotConfig(**kwargs)
elif robot_type == "lekiwi":
return LeKiwiRobotConfig(**kwargs)
elif robot_type == 'realman':
return RealmanRobotConfig(**kwargs)
else:
raise ValueError(f"Robot type '{robot_type}' is not available.")
@@ -78,6 +82,12 @@ def make_robot_from_config(config: RobotConfig):
from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator
return MobileManipulator(config)
elif isinstance(config, RealmanRobotConfig):
from lerobot.common.robot_devices.robots.realman import RealmanRobot
return RealmanRobot(config)
else:
from lerobot.common.robot_devices.robots.stretch import StretchRobot
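For reference, a hedged sketch of going through these factory helpers end to end (the import path is assumed to be `lerobot.common.robot_devices.robots.utils`, as in upstream LeRobot):

```python
from lerobot.common.robot_devices.robots.utils import make_robot_config, make_robot_from_config

cfg = make_robot_config("realman", inference_time=False)  # -> RealmanRobotConfig
robot = make_robot_from_config(cfg)                       # -> RealmanRobot
robot.connect()
obs_dict, action_dict = robot.teleop_step(record_data=True)
robot.disconnect()
```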

View File

@@ -0,0 +1,461 @@
import pygame
import threading
import time
import serial
import binascii
import logging
import yaml
from typing import Dict
from Robotic_Arm.rm_robot_interface import *
class ServoArm:
def __init__(self, config_file="config.yaml"):
"""初始化机械臂的串口连接并发送初始数据。
Args:
config_file (str): 配置文件的路径。
"""
self.config = self._load_config(config_file)
self.port = self.config["port"]
self.baudrate = self.config["baudrate"]
self.hex_data = self.config["hex_data"]
self.arm_axis = self.config.get("arm_axis", 7)
try:
self.serial_conn = serial.Serial(self.port, self.baudrate, timeout=0)
self.bytes_to_send = binascii.unhexlify(self.hex_data.replace(" ", ""))
self.serial_conn.write(self.bytes_to_send)
time.sleep(1)
self.connected = True
logging.info(f"串口连接成功: {self.port}")
except Exception as e:
logging.error(f"串口连接失败: {e}")
self.connected = False
def _load_config(self, config_file):
"""加载配置文件。
Args:
config_file (str): 配置文件的路径。
Returns:
dict: 配置文件内容。
"""
try:
with open(config_file, "r") as file:
config = yaml.safe_load(file)
return config
except Exception as e:
logging.error(f"配置文件加载失败: {e}")
# 返回默认配置
return {
"port": "/dev/ttyUSB0",
"baudrate": 460800,
"hex_data": "55 AA 02 00 00 67",
"arm_axis": 6
}
def _bytes_to_signed_int(self, byte_data):
"""将字节数据转换为有符号整数。
Args:
byte_data (bytes): 字节数据。
Returns:
int: 有符号整数。
"""
return int.from_bytes(byte_data, byteorder="little", signed=True)
def _parse_joint_data(self, hex_received):
"""解析接收到的十六进制数据并提取关节数据。
Args:
hex_received (str): 接收到的十六进制字符串数据。
Returns:
dict: 解析后的关节数据。
"""
logging.debug(f"hex_received: {hex_received}")
joints = {}
for i in range(self.arm_axis):
start = 14 + i * 10
end = start + 8
joint_hex = hex_received[start:end]
joint_byte_data = bytearray.fromhex(joint_hex)
joint_value = self._bytes_to_signed_int(joint_byte_data) / 10000.0
joints[f"joint_{i+1}"] = joint_value
grasp_start = 14 + self.arm_axis*10
grasp_hex = hex_received[grasp_start:grasp_start+8]
grasp_byte_data = bytearray.fromhex(grasp_hex)
# normalize the gripper value
grasp_value = self._bytes_to_signed_int(grasp_byte_data)/1000
joints["grasp"] = grasp_value
return joints
def get_joint_actions(self):
"""从串口读取数据并解析关节动作。
Returns:
dict: 包含关节数据的字典。
"""
if not self.connected:
return {}
try:
self.serial_conn.write(self.bytes_to_send)
bytes_received = self.serial_conn.read(self.serial_conn.inWaiting())
if len(bytes_received) == 0:
return {}
hex_received = binascii.hexlify(bytes_received).decode("utf-8").upper()
actions = self._parse_joint_data(hex_received)
return actions
except Exception as e:
logging.error(f"读取串口数据错误: {e}")
return {}
def set_gripper_action(self, action):
"""设置夹爪动作。
Args:
action (int): 夹爪动作值。
"""
if not self.connected:
return
try:
action = int(action * 1000)
action_bytes = action.to_bytes(4, byteorder="little", signed=True)
self.bytes_to_send = self.bytes_to_send[:74] + action_bytes + self.bytes_to_send[78:]
except Exception as e:
logging.error(f"设置夹爪动作错误: {e}")
def close(self):
"""关闭串口连接"""
if self.connected and hasattr(self, 'serial_conn'):
self.serial_conn.close()
self.connected = False
logging.info("串口连接已关闭")
class HybridController:
def __init__(self, init_info):
# initialize pygame and the gamepad
pygame.init()
pygame.joystick.init()
# check that a gamepad is connected
if pygame.joystick.get_count() == 0:
raise Exception("No gamepad detected")
# initialize the gamepad
self.joystick = pygame.joystick.Joystick(0)
self.joystick.init()
# joystick dead zone
self.deadzone = 0.15
# control mode: True = joint control (primary mode), False = end-effector control
self.joint_control_mode = True
# fine-control mode
self.fine_control_mode = False
# initialize the end-effector pose and joint angles
self.init_joint = init_info['init_joint']
self.init_pose = init_info.get('init_pose', [0]*6)
self.max_gripper = init_info['max_gripper']
self.min_gripper = init_info['min_gripper']
servo_config_file = init_info['servo_config_file']
self.joint = self.init_joint.copy()
self.pose = self.init_pose.copy()
self.pose_speeds = [0.0] * 6
self.joint_speeds = [0.0] * 6
self.tozero = False
# master-arm joint state
self.master_joint_actions = {}
self.use_master_arm = False
# end-effector pose limits
self.pose_limits = [
(-0.800, 0.800), # X (m)
(-0.800, 0.800), # Y (m)
(-0.800, 0.800), # Z (m)
(-3.14, 3.14), # RX (rad)
(-3.14, 3.14), # RY (rad)
(-3.14, 3.14) # RZ (rad)
]
# joint angle limits (degrees)
self.joint_limits = [
(-180, 180), # joint 1
(-180, 180), # joint 2
(-180, 180), # joint 3
(-180, 180), # joint 4
(-180, 180), # joint 5
(-180, 180) # joint 6
]
# control parameters
self.linear_step = 0.002 # linear step size (m)
self.angular_step = 0.01 # angular step size (rad)
# gripper state and speed
self.gripper_speed = 10
self.gripper = self.min_gripper
# initialize serial communication (to read the master-arm joint state)
self.servo_arm = None
if servo_config_file:
try:
self.servo_arm = ServoArm(servo_config_file)
self.use_master_arm = True
logging.info("串口主臂连接成功,启用主从控制模式")
except Exception as e:
logging.error(f"串口主臂连接失败: {e}")
self.use_master_arm = False
# start the update thread
self.running = True
self.thread = threading.Thread(target=self.update_controller)
self.thread.start()
print("混合控制器已启动")
print("主控制模式: 关节控制")
if self.use_master_arm:
print("主从控制: 启用")
print("Back按钮: 切换控制模式(关节/末端)")
print("L3按钮: 切换精细控制模式")
print("Start按钮: 重置到初始位置")
def _apply_nonlinear_mapping(self, value):
"""应用非线性映射以提高控制精度"""
sign = 1 if value >= 0 else -1
return sign * (abs(value) ** 2)
def _normalize_angle(self, angle):
"""将角度归一化到[-π, π]范围内"""
import math
while angle > math.pi:
angle -= 2 * math.pi
while angle < -math.pi:
angle += 2 * math.pi
return angle
def update_controller(self):
while self.running:
try:
pygame.event.pump()
except Exception as e:
print(f"控制器错误: {e}")
self.stop()
continue
# check the control-mode toggle (Back button)
if self.joystick.get_button(6): # Back button
self.joint_control_mode = not self.joint_control_mode
mode_str = "joint control" if self.joint_control_mode else "end-effector pose control"
print(f"Switched to {mode_str} mode")
time.sleep(0.3) # debounce to avoid repeated triggers
# check the fine-control-mode toggle (L3 button)
if self.joystick.get_button(10): # L3 button
self.fine_control_mode = not self.fine_control_mode
print(f"Switched to {'fine' if self.fine_control_mode else 'normal'} control mode")
time.sleep(0.3) # debounce to avoid repeated triggers
# check the reset button (Start button)
if self.joystick.get_button(7): # Start button
print("Resetting the arm to its initial position...")
# print("init_joint", self.init_joint.copy())
self.tozero = True
self.joint = self.init_joint.copy()
self.pose = self.init_pose.copy()
self.pose_speeds = [0.0] * 6
self.joint_speeds = [0.0] * 6
self.gripper_speed = 10
self.gripper = self.min_gripper
print("机械臂已重置到初始位置")
time.sleep(0.3) # 防止多次触发
# 从串口获取主臂关节状态
if self.servo_arm and self.servo_arm.connected:
try:
self.master_joint_actions = self.servo_arm.get_joint_actions()
if self.master_joint_actions:
logging.debug(f"主臂关节状态: {self.master_joint_actions}")
except Exception as e:
logging.error(f"获取主臂状态错误: {e}")
self.master_joint_actions = {}
# print(self.master_joint_actions)
# update the control logic for the current mode
if self.joint_control_mode:
# in joint-control mode, prefer master-arm data; the Xbox controller is auxiliary
self.update_joint_control()
else:
# end-effector mode is driven by the Xbox controller
self.update_end_pose()
time.sleep(0.02)
# print('gripper:', self.gripper)
def update_joint_control(self):
"""更新关节角度控制 - 优先使用主臂数据"""
if self.use_master_arm and self.master_joint_actions:
# 主从控制模式:直接使用主臂的关节角度
try:
# 将主臂关节角度映射到从臂
for i in range(6): # 假设只有6个关节需要控制
joint_key = f"joint_{i+1}"
if joint_key in self.master_joint_actions:
# use the master-arm joint angle directly (already in degrees)
self.joint[i] = self.master_joint_actions[joint_key]
# apply the joint limits
min_val, max_val = self.joint_limits[i]
self.joint[i] = max(min_val, min(max_val, self.joint[i]))
# print(self.joint)
logging.debug(f"主臂关节映射到从臂: {self.joint[:6]}")
except Exception as e:
logging.error(f"主臂数据映射错误: {e}")
# 如果有主臂夹爪数据,使用主臂夹爪状态
if self.use_master_arm and "grasp" in self.master_joint_actions:
self.gripper = self.master_joint_actions["grasp"] * 1000
self.joint[-1] = self.gripper
def update_end_pose(self):
"""更新末端位姿控制"""
# 根据控制模式调整步长
current_linear_step = self.linear_step * (0.1 if self.fine_control_mode else 1.0)
current_angular_step = self.angular_step * (0.1 if self.fine_control_mode else 1.0)
# 方向键控制XY
hat = self.joystick.get_hat(0)
hat_up = hat[1] == 1 # Y+
hat_down = hat[1] == -1 # Y-
hat_left = hat[0] == -1 # X-
hat_right = hat[0] == 1 # X+
# the right stick controls Z
right_y_raw = -self.joystick.get_axis(4)
# the left stick controls RZ
left_y_raw = -self.joystick.get_axis(1)
# apply the dead zone
right_y = 0.0 if abs(right_y_raw) < self.deadzone else right_y_raw
left_y = 0.0 if abs(left_y_raw) < self.deadzone else left_y_raw
# compute the per-axis speeds
self.pose_speeds[1] = current_linear_step if hat_up else (-current_linear_step if hat_down else 0.0) # Y
self.pose_speeds[0] = -current_linear_step if hat_left else (current_linear_step if hat_right else 0.0) # X
# Z speed from the right-stick Y axis
z_mapping = self._apply_nonlinear_mapping(right_y)
self.pose_speeds[2] = z_mapping * current_linear_step # Z
# L1/R1 control the RX rotation
LB = self.joystick.get_button(4) # RX-
RB = self.joystick.get_button(5) # RX+
self.pose_speeds[3] = (-current_angular_step if LB else (current_angular_step if RB else 0.0))
# triangle/square control the RY rotation
triangle = self.joystick.get_button(2) # RY+
square = self.joystick.get_button(3) # RY-
self.pose_speeds[4] = (current_angular_step if triangle else (-current_angular_step if square else 0.0))
# the left-stick Y axis controls the RZ rotation
rz_mapping = self._apply_nonlinear_mapping(left_y)
self.pose_speeds[5] = rz_mapping * current_angular_step * 2 # RZ
# gripper control (circle/cross)
circle = self.joystick.get_button(1) # open the gripper
cross = self.joystick.get_button(0) # close the gripper
if circle:
self.gripper = min(self.max_gripper, self.gripper + self.gripper_speed)
elif cross:
self.gripper = max(self.min_gripper, self.gripper - self.gripper_speed)
# update the end-effector pose
for i in range(6):
self.pose[i] += self.pose_speeds[i]
# normalize the rotation angles
for i in range(3, 6):
self.pose[i] = self._normalize_angle(self.pose[i])
def update_joint_state(self, joint):
self.joint = joint
# self.tozero = False
def update_endpose_state(self, end_pose):
self.pose = end_pose
# self.tozero = False
def update_tozero_state(self, tozero):
self.tozero = tozero
def get_action(self) -> Dict:
"""获取当前控制命令"""
return {
'control_mode': 'joint' if self.joint_control_mode else 'end_pose',
'use_master_arm': self.use_master_arm,
'master_joint_actions': self.master_joint_actions,
'end_pose': {
'X': self.pose[0],
'Y': self.pose[1],
'Z': self.pose[2],
'RX': self.pose[3],
'RY': self.pose[4],
'RZ': self.pose[5],
},
'joint_angles': self.joint,
'gripper': int(self.gripper),
'tozero': self.tozero
}
def stop(self):
"""停止控制器"""
self.running = False
if self.thread.is_alive():
self.thread.join()
if self.servo_arm:
self.servo_arm.close()
pygame.quit()
print("混合控制器已退出")
def reset(self):
"""重置到初始状态"""
self.joint = self.init_joint.copy()
self.pose = self.init_pose.copy()
self.pose_speeds = [0.0] * 6
self.joint_speeds = [0.0] * 6
self.gripper_speed = 10
self.gripper = self.min_gripper
print("已重置到初始状态")
# Usage example
if __name__ == "__main__":
# initialize the Realman arm
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
# create the arm connection
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
print(f"Arm connection ID: {handle.id}")
init_pose = arm.rm_get_current_arm_state()[1]['pose']
with open('/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/realman_mix.yaml', "r") as file:
config = yaml.safe_load(file)
config['init_pose'] = init_pose
arm_controller = HybridController(config)
try:
while True:
print(arm_controller.get_action())
time.sleep(0.1)
except KeyboardInterrupt:
arm_controller.stop()

View File

@@ -0,0 +1,4 @@
init_joint: [-90, 90, 90, -90, -90, 90]
max_gripper: 990
min_gripper: 10
servo_config_file: "/home/maic/LYT/lerobot/lerobot/common/robot_devices/teleop/servo_arm.yaml"

View File

@@ -0,0 +1,5 @@
port: /dev/ttyUSB0
right_port: /dev/ttyUSB1
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6
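For reference, a hedged sketch of how this file is consumed by the `ServoArm` helper from the teleop module above (the path below is an assumption; point it at your own copy):

```python
from lerobot.common.robot_devices.teleop.gamepad import ServoArm

servo = ServoArm(config_file="lerobot/common/robot_devices/teleop/servo_arm.yaml")  # assumed path
actions = servo.get_joint_actions()  # {'joint_1': deg, ..., 'joint_6': deg, 'grasp': normalized value}
print(actions)
servo.close()
```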

View File

@@ -273,7 +273,6 @@ def record(
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()
@@ -290,6 +289,9 @@ def record(
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
# import pdb
# pdb.set_trace()
recorded_episodes = 0
while True:
if recorded_episodes >= cfg.num_episodes:

125
realman.md Normal file
View File

@@ -0,0 +1,125 @@
# Install
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
```bash
conda create -y -n lerobot python=3.10
conda activate lerobot
```
Install 🤗 LeRobot:
```bash
pip install -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
# pip uninstall numpy
# pip install numpy==1.26.0
# pip install pynput
```
/!\ For Linux only, ffmpeg and opencv require a conda install for now. Run this exact sequence of commands:
```bash
conda install ffmpeg=7.1.1 -c conda-forge
# pip uninstall opencv-python
# conda install "opencv>=4.10.0"
```
Install Realman SDK:
```bash
pip install Robotic_Arm==1.0.4.1
pip install pygame
```
# Integrating piper into lerobot
See lerobot_piper_tutorial/1. 🤗 LeRobot新增机械臂的一般流程.pdf
# Teleoperate
```bash
cd piper_scripts/
bash can_activate.sh can0 1000000
cd ..
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=false \
--control.type=teleoperate
```
# Record
Set dataset root path
```bash
HF_USER=$PWD/data
echo $HF_USER
```
```bash
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=false \
--control.type=record \
--control.fps=30 \
--control.single_task="move" \
--control.repo_id=${HF_USER}/test \
--control.num_episodes=2 \
--control.warmup_time_s=2 \
--control.episode_time_s=10 \
--control.reset_time_s=10 \
--control.play_sounds=true \
--control.push_to_hub=false
```
Press the right arrow -> at any time during episode recording to stop early and go to resetting. Do the same during resetting to stop early and go to the next episode recording.
Press the left arrow <- at any time during episode recording or resetting to stop early, cancel the current episode, and re-record it.
Press Escape (ESC) at any time during episode recording to end the session early and go straight to video encoding and dataset uploading.
# Visualize
```bash
python lerobot/scripts/visualize_dataset.py \
--repo-id ${HF_USER}/test \
--episode-index 0
```
# Replay
```bash
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=false \
--control.type=replay \
--control.fps=30 \
--control.repo_id=${HF_USER}/test \
--control.episode=0
```
# Caution
1. In lerobot/common/datasets/video_utils.py, the vcodec is set to **libopenh264**; check which codecs your ffmpeg build supports with **ffmpeg -codecs**, as shown below.
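For example, to check which H.264 encoders your ffmpeg build provides:
```bash
ffmpeg -codecs | grep -i h264
# or list encoders only
ffmpeg -encoders | grep -i 264
```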
# Train
For the detailed training workflow, see lerobot_piper_tutorial/2. 🤗 AutoDL训练.pdf
```bash
python lerobot/scripts/train.py \
--dataset.repo_id=${HF_USER}/jack \
--policy.type=act \
--output_dir=outputs/train/act_jack \
--job_name=act_jack \
--device=cuda \
--wandb.enable=true
```
# Inference
Inference still uses the record loop in control_robot.py. Set **--robot.inference_time=true** so the gamepad can be left out.
```bash
python lerobot/scripts/control_robot.py \
--robot.type=piper \
--robot.inference_time=true \
--control.type=record \
--control.fps=30 \
--control.single_task="move" \
--control.repo_id=$USER/eval_act_jack \
--control.num_episodes=1 \
--control.warmup_time_s=2 \
--control.episode_time_s=30 \
--control.reset_time_s=10 \
--control.push_to_hub=false \
--control.policy.path=outputs/train/act_koch_pick_place_lego/checkpoints/latest/pretrained_model
```

View File

@@ -0,0 +1,31 @@
from Robotic_Arm.rm_robot_interface import *
armleft = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
armright = RoboticArm()
lefthandle = armleft.rm_create_robot_arm("169.254.128.18", 8080)
print("机械臂ID", lefthandle.id)
righthandle = armright.rm_create_robot_arm("169.254.128.19", 8080)
print("机械臂ID", righthandle.id)
# software_info = armleft.rm_get_arm_software_info()
# if software_info[0] == 0:
# print("\n================== Arm Software Information ==================")
# print("Arm Model: ", software_info[1]['product_version'])
# print("Algorithm Library Version: ", software_info[1]['algorithm_info']['version'])
# print("Control Layer Software Version: ", software_info[1]['ctrl_info']['version'])
# print("Dynamics Version: ", software_info[1]['dynamic_info']['model_version'])
# print("Planning Layer Software Version: ", software_info[1]['plan_info']['version'])
# print("==============================================================\n")
# else:
# print("\nFailed to get arm software information, Error code: ", software_info[0], "\n")
print("Left: ", armleft.rm_get_current_arm_state())
print("Left: ", armleft.rm_get_arm_all_state())
armleft.rm_movej_p()
# print("Right: ", armright.rm_get_current_arm_state())
# disconnect everything and destroy the SDK threads
RoboticArm.rm_destory()

View File

@@ -0,0 +1,15 @@
from Robotic_Arm.rm_robot_interface import *
import time
# instantiate the RoboticArm class
arm = RoboticArm(rm_thread_mode_e.RM_TRIPLE_MODE_E)
# create the arm connection and print the connection id
handle = arm.rm_create_robot_arm("192.168.3.18", 8080)
print(handle.id)
print(arm.rm_movep_follow([-0.330512, 0.255993, -0.161205, 3.141, 0.0, -1.57]))
time.sleep(2)
# print(arm.rm_movep_follow([0.3, 0, 0.3, 3.14, 0, 0]))
# time.sleep(2)
arm.rm_delete_robot_arm()

View File

View File

@@ -0,0 +1,4 @@
__pycache__/
*.pyc
*.pyo
*.pt

View File

@@ -0,0 +1,33 @@
[tool.poetry]
name = "shadow_camera"
version = "0.1.0"
description = "camera class, currently includes realsense"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
#include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
"Operating System :: POSIX :: Linux amd64",
"Programming Language :: Python :: 3.10",
]
[tool.poetry.dependencies]
python = ">=3.9"
numpy = ">=2.0.1"
opencv-python = ">=4.10.0.84"
pyrealsense2 = ">=2.55.1.6486"
[tool.poetry.dev-dependencies] # development-time dependencies, e.g. testing and documentation tools
pytest = ">=8.3"
black = ">=24.10.0"
[tool.poetry.plugins."scripts"] # command-line scripts so users can run the specified functions from the CLI
[tool.poetry.group.dev.dependencies]
[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,38 @@
from abc import ABCMeta, abstractmethod
class BaseCamera(metaclass=ABCMeta):
"""摄像头基类"""
def __init__(self):
pass
@abstractmethod
def start_camera(self):
"""启动相机"""
pass
@abstractmethod
def stop_camera(self):
"""停止相机"""
pass
@abstractmethod
def set_resolution(self, resolution_width, resolution_height):
"""设置相机彩色图像分辨率"""
pass
@abstractmethod
def set_frame_rate(self, fps):
"""设置相机彩色图像帧率"""
pass
@abstractmethod
def read_frame(self):
"""读取一帧彩色图像和深度图像"""
pass
@abstractmethod
def get_camera_intrinsics(self):
"""获取彩色图像和深度图像的内参"""
pass

View File

@@ -0,0 +1,38 @@
from shadow_camera import base_camera
import cv2
class OpenCVCamera(base_camera.BaseCamera):
"""基于OpenCV的摄像头类"""
def __init__(self, device_id=0):
"""初始化视频捕获
参数:
device_id: 摄像头设备ID
"""
self.cap = cv2.VideoCapture(device_id)
def get_frame(self):
"""获取当前帧
返回:
frame: 当前帧的图像数据,取不到时返回None
"""
ret, frame = self.cap.read()
return frame if ret else None
def get_frame_info(self):
"""获取当前帧信息
返回:
dict: 帧信息字典
"""
width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
channels = int(self.cap.get(cv2.CAP_PROP_FRAME_CHANNELS))
return {
'width': width,
'height': height,
'channels': channels
}

View File

@@ -0,0 +1,280 @@
import time
import logging
import numpy as np
import pyrealsense2 as rs
import base_camera
# logging configuration
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class RealSenseCamera(base_camera.BaseCamera):
"""Intel RealSense相机类"""
def __init__(self, serial_num=None, is_depth_frame=False):
"""
Initialize the camera object.
:param serial_num: Camera serial number, defaults to None
"""
super().__init__()
self._color_resolution = [640, 480]
self._depth_resolution = [640, 480]
self._color_frames_rate = 30
self._depth_frames_rate = 15
self.timestamp = 0
self.color_timestamp = 0
self.depth_timestamp = 0
self._colorizer = rs.colorizer()
self._config = rs.config()
self.is_depth_frame = is_depth_frame
self.camera_on = False
self.serial_num = serial_num
def get_serial_num(self):
serial_num = {}
context = rs.context()
devices = context.query_devices() # query all devices
if len(context.devices) > 0:
for i, device in enumerate(devices):
serial_num[i] = device.get_info(rs.camera_info.serial_number)
logging.info(f"Detected serial numbers: {serial_num}")
return serial_num
def _set_config(self):
if self.serial_num is not None:
logging.info(f"Setting device with serial number: {self.serial_num}")
self._config.enable_device(self.serial_num)
self._config.enable_stream(
rs.stream.color,
self._color_resolution[0],
self._color_resolution[1],
rs.format.rgb8,
self._color_frames_rate,
)
if self.is_depth_frame:
self._config.enable_stream(
rs.stream.depth,
self._depth_resolution[0],
self._depth_resolution[1],
rs.format.z16,
self._depth_frames_rate,
)
def start_camera(self):
"""
Start the camera and read its intrinsics. If frame alignment is used later, both sets of intrinsics are the color intrinsics.
"""
self._pipeline = rs.pipeline()
if self.is_depth_frame:
self.point_cloud = rs.pointcloud()
self._align = rs.align(rs.stream.color)
self._set_config()
self.profile = self._pipeline.start(self._config)
if self.is_depth_frame:
self._depth_intrinsics = (
self.profile.get_stream(rs.stream.depth)
.as_video_stream_profile()
.get_intrinsics()
)
self._color_intrinsics = (
self.profile.get_stream(rs.stream.color)
.as_video_stream_profile()
.get_intrinsics()
)
self.camera_on = True
logging.info("Camera started successfully")
logging.info(
f"Camera started with color resolution: {self._color_resolution}, depth resolution: {self._depth_resolution}"
)
logging.info(
f"Color FPS: {self._color_frames_rate}, Depth FPS: {self._depth_frames_rate}"
)
def stop_camera(self):
"""
Stop the camera.
"""
self._pipeline.stop()
self.camera_on = False
logging.info("Camera stopped")
def set_resolution(self, color_resolution, depth_resolution):
self._color_resolution = color_resolution
self._depth_resolution = depth_resolution
logging.info(
"Optional color resolution:"
"[320, 180] [320, 240] [424, 240] [640, 360] [640, 480]"
"[848, 480] [960, 540] [1280, 720] [1920, 1080]"
)
logging.info(
"Optional depth resolution:"
"[256, 144] [424, 240] [480, 270] [640, 360] [640, 400]"
"[640, 480] [848, 100] [848, 480] [1280, 720] [1280, 800]"
)
logging.info(f"Set color resolution to: {color_resolution}")
logging.info(f"Set depth resolution to: {depth_resolution}")
def set_frame_rate(self, color_fps, depth_fps):
self._color_frames_rate = color_fps
self._depth_frames_rate = depth_fps
logging.info("Optional color fps: 6 15 30 60 ")
logging.info("Optional depth fps: 6 15 30 60 90 100 300")
logging.info(f"Set color FPS to: {color_fps}")
logging.info(f"Set depth FPS to: {depth_fps}")
# TODO: adjust the white balance for compensation
# def set_exposure(self, exposure):
def read_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
"""
Read one color frame and one depth frame.
:return: NumPy arrays of the color and depth images
"""
while not self.camera_on:
time.sleep(0.5)
color_image = None
depth_image = None
colorized_depth = None
point_cloud = None
try:
frames = self._pipeline.wait_for_frames()
if is_color:
color_frame = frames.get_color_frame()
color_image = np.asanyarray(color_frame.get_data())
else:
color_image = None
if is_depth:
depth_frame = frames.get_depth_frame()
depth_image = np.asanyarray(depth_frame.get_data())
else:
depth_image = None
colorized_depth = (
np.asanyarray(self._colorizer.colorize(depth_frame).get_data())
if is_colorized_depth
else None
)
point_cloud = (
np.asanyarray(self.point_cloud.calculate(depth_frame).get_vertices())
if is_point_cloud
else None
)
# timestamps are in ms; after alignment the color timestamp > depth = aligned, so use the color timestamp
self.color_timestamp = color_frame.get_timestamp()
if self.is_depth_frame:
self.depth_timestamp = depth_frame.get_timestamp()
except Exception as e:
logging.warning(e)
if "Frame didn't arrive within 5000" in str(e):
logging.warning("Frame didn't arrive within 5000ms, resetting device")
self.stop_camera()
self.start_camera()
return color_image, depth_image, colorized_depth, point_cloud
def read_align_frame(self, is_color=True, is_depth=True, is_colorized_depth=False, is_point_cloud=False):
"""
Read one aligned color frame and depth frame.
:return: NumPy arrays of the color and depth images
"""
while not self.camera_on:
time.sleep(0.5)
try:
frames = self._pipeline.wait_for_frames()
aligned_frames = self._align.process(frames)
aligned_color_frame = aligned_frames.get_color_frame()
self._aligned_depth_frame = aligned_frames.get_depth_frame()
color_image = np.asanyarray(aligned_color_frame.get_data())
depth_image = np.asanyarray(self._aligned_depth_frame.get_data())
colorized_depth = (
np.asanyarray(
self._colorizer.colorize(self._aligned_depth_frame).get_data()
)
if is_colorized_depth
else None
)
if is_point_cloud:
points = self.point_cloud.calculate(self._aligned_depth_frame)
# convert the tuple data to a NumPy array
point_cloud = np.array(
[[point[0], point[1], point[2]] for point in points.get_vertices()]
)
else:
point_cloud = None
# timestamps are in ms; after alignment the color timestamp > depth = aligned, so use the color timestamp
self.timestamp = aligned_color_frame.get_timestamp()
return color_image, depth_image, colorized_depth, point_cloud
except Exception as e:
if "Frame didn't arrive within 5000" in str(e):
logging.warning("Frame didn't arrive within 5000ms, resetting device")
self.stop_camera()
self.start_camera()
# device = self.profile.get_device()
# device.hardware_reset()
def get_camera_intrinsics(self):
"""
Get the intrinsics of the color and depth images.
:return: Color and depth intrinsics
"""
# width/height: .width, .height; focal length: .fx, .fy; principal point: .ppx, .ppy; distortion coefficients: .coeffs
logging.info("Getting camera intrinsics")
logging.info(
"Width and height: .width, .height; Focal length: .fx, .fy; Pixel coordinates: .ppx, .ppy; Distortion coefficient: .coeffs"
)
return self._color_intrinsics, self._depth_intrinsics
def get_3d_camera_coordinate(self, depth_pixel, align=False):
"""
Get the 3D coordinate in the depth-camera frame.
:param depth_pixel: Depth pixel coordinate
:param align: Whether to use the aligned (color) intrinsics
:return: Depth value and camera coordinate
"""
if not hasattr(self, "_aligned_depth_frame"):
raise AttributeError(
"Aligned depth frame not set. Call read_align_frame() first."
)
distance = self._aligned_depth_frame.get_distance(
depth_pixel[0], depth_pixel[1]
)
intrinsics = self._color_intrinsics if align else self._depth_intrinsics
camera_coordinate = rs.rs2_deproject_pixel_to_point(
intrinsics, depth_pixel, distance
)
return distance, camera_coordinate
if __name__ == "__main__":
camera = RealSenseCamera(is_depth_frame=False)
camera.get_serial_num()
camera.start_camera()
# camera.set_frame_rate(60, 60)
color_image, depth_image, colorized_depth, point_cloud = camera.read_frame()
camera.stop_camera()
logging.info(f"Color image shape: {color_image.shape}")
# logging.info(f"Depth image shape: {depth_image.shape}")
# logging.info(f"Colorized depth image shape: {colorized_depth.shape}")
# logging.info(f"Point cloud shape: {point_cloud.shape}")
logging.info(f"Color timestamp: {camera.timestamp}")
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
logging.info(f"Color timestamp: {camera.color_timestamp}")
# logging.info(f"Depth timestamp: {camera.depth_timestamp}")
logging.info("Test passed")

View File

@@ -0,0 +1,101 @@
import pyrealsense2 as rs
import numpy as np
import h5py
import time
import threading
import keyboard # used to listen for keyboard input
# global variables
is_recording = False # flag controlling the recording state
color_images = [] # stores color images
depth_images = [] # stores depth images
timestamps = [] # stores timestamps
# configure the D435 camera
def configure_camera():
pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30) # color stream
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30) # depth stream
pipeline.start(config)
return pipeline
# listen for keyboard input to control the recording state
def listen_for_keyboard():
global is_recording
while True:
if keyboard.is_pressed('s'): # press 's' to start recording
is_recording = True
print("Recording started.")
time.sleep(0.5) # debounce to avoid repeated triggers
elif keyboard.is_pressed('q'): # press 'q' to stop recording
is_recording = False
print("Recording stopped.")
time.sleep(0.5) # debounce to avoid repeated triggers
elif keyboard.is_pressed('e'): # press 'e' to exit the program
print("Exiting program.")
exit()
time.sleep(0.1)
# capture image data
def capture_frames(pipeline):
global is_recording, color_images, depth_images, timestamps
try:
while True:
if is_recording:
frames = pipeline.wait_for_frames()
color_frame = frames.get_color_frame()
depth_frame = frames.get_depth_frame()
if not color_frame or not depth_frame:
continue
# get the current timestamp
timestamp = time.time()
# convert the frames to numpy arrays
color_image = np.asanyarray(color_frame.get_data())
depth_image = np.asanyarray(depth_frame.get_data())
# store the data
color_images.append(color_image)
depth_images.append(depth_image)
timestamps.append(timestamp)
print(f"Captured frame at {timestamp}")
else:
time.sleep(0.1) # if not recording, wait for a while
finally:
pipeline.stop()
# save to an HDF5 file
def save_to_hdf5(color_images, depth_images, timestamps, filename="output.h5"):
with h5py.File(filename, "w") as f:
f.create_dataset("color_images", data=np.array(color_images), compression="gzip")
f.create_dataset("depth_images", data=np.array(depth_images), compression="gzip")
f.create_dataset("timestamps", data=np.array(timestamps), compression="gzip")
print(f"Data saved to {filename}")
# main function
def main():
global is_recording, color_images, depth_images, timestamps
# start the keyboard-listener thread
keyboard_thread = threading.Thread(target=listen_for_keyboard)
keyboard_thread.daemon = True
keyboard_thread.start()
# configure the camera
pipeline = configure_camera()
# start capturing images
capture_frames(pipeline)
# save the data after recording finishes
if color_images and depth_images and timestamps:
save_to_hdf5(color_images, depth_images, timestamps, "mobile_aloha_data.h5")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,152 @@
import os
import cv2
import time
import numpy as np
from os import path
import pyrealsense2 as rs
from shadow_camera import realsense
import logging
def test_camera():
camera = realsense.RealSenseCamera('241122071186')
camera.start_camera()
while True:
# result = camera.read_align_frame()
# if result is None:
# print('is None')
# continue
# start_time = time.time()
color_image, depth_image, colorized_depth, vtx = camera.read_frame()
color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR)
print(f"color_image: {color_image.shape}")
# print(f"Time: {end_time - start_time}")
cv2.imshow("bgr_image", color_image)
if cv2.waitKey(1) & 0xFF == ord("q"):
break
camera.stop_camera()
def test_get_serial_num():
camera = realsense.RealSenseCamera()
device = camera.get_serial_num()
class CameraCapture:
def __init__(self, camera_serial_num=None, save_dir="./save"):
self._camera_serial_num = camera_serial_num
self._color_save_dir = path.join(save_dir, "color")
self._depth_save_dir = path.join(save_dir, "depth")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(self._color_save_dir, exist_ok=True)
os.makedirs(self._depth_save_dir, exist_ok=True)
def get_serial_num(self):
self._camera_serial_num = {}
camera_names = ["left", "right", "head", "table"]
context = rs.context()
devices = context.query_devices() # query all devices
if len(context.devices) > 0:
for i, device in enumerate(devices):
self._camera_serial_num[camera_names[i]] = device.get_info(
rs.camera_info.serial_number
)
print(self._camera_serial_num)
return self._camera_serial_num
def start_camera(self):
if self._camera_serial_num is None:
self.get_serial_num()
self._camera_left = realsense.RealSenseCamera(self._camera_serial_num["left"])
self._camera_right = realsense.RealSenseCamera(self._camera_serial_num["right"])
self._camera_head = realsense.RealSenseCamera(self._camera_serial_num["head"])
self._camera_left.start_camera()
self._camera_right.start_camera()
self._camera_head.start_camera()
def stop_camera(self):
self._camera_left.stop_camera()
self._camera_right.stop_camera()
self._camera_head.stop_camera()
def _save_datas(self, timestamp, color_image, depth_image, camera_name):
color_filename = path.join(
self._color_save_dir, f"{timestamp}" + camera_name + ".jpg"
)
depth_filename = path.join(
self._depth_save_dir, f"{timestamp}" + camera_name + ".png"
)
cv2.imwrite(color_filename, color_image)
cv2.imwrite(depth_filename, depth_image)
def capture_images(self):
while True:
(
color_image_left,
depth_image_left,
_,
_,
) = self._camera_left.read_align_frame()
(
color_image_right,
depth_image_right,
_,
_,
) = self._camera_right.read_align_frame()
(
color_image_head,
depth_image_head,
_,
point_cloud3,
) = self._camera_head.read_align_frame()
bgr_color_image_left = cv2.cvtColor(color_image_left, cv2.COLOR_RGB2BGR)
bgr_color_image_right = cv2.cvtColor(color_image_right, cv2.COLOR_RGB2BGR)
bgr_color_image_head = cv2.cvtColor(color_image_head, cv2.COLOR_RGB2BGR)
timestamp = time.time() * 1000
cv2.imshow("Camera left", bgr_color_image_left)
cv2.imshow("Camera right", bgr_color_image_right)
cv2.imshow("Camera head", bgr_color_image_head)
# self._save_datas(
# timestamp, bgr_color_image_left, depth_image_left, "left"
# )
# self._save_datas(
# timestamp, bgr_color_image_right, depth_image_right, "right"
# )
# self._save_datas(
# timestamp, bgr_color_image_head, depth_image_head, "head"
# )
if cv2.waitKey(1) & 0xFF == ord("q"):
break
cv2.destroyAllWindows()
if __name__ == "__main__":
#test_camera()
test_get_serial_num()
"""
Pass camera serial numbers to assign the left/right cameras:
dict{'left': '241222075132', 'right': '242322076532', 'head': '242322076532'}
Save path:
str: ./save
If nothing is passed, serial numbers are assigned automatically (left/right/head not distinguished) and the save path is ./save
"""
# capture = CameraCapture()
# capture.get_serial_num()
# test_get_serial_num()
# capture.start_camera()
# capture.capture_images()
# capture.stop_camera()

View File

@@ -0,0 +1,71 @@
import pytest
import pyrealsense2 as rs
from shadow_camera.realsense import RealSenseCamera
class TestRealSenseCamera:
@pytest.fixture(autouse=True)
def setup_camera(self):
self.camera = RealSenseCamera()
def test_get_serial_num(self):
serial_nums = self.camera.get_serial_num()
assert isinstance(serial_nums, dict)
assert len(serial_nums) > 0
def test_start_stop_camera(self):
self.camera.start_camera()
assert self.camera.camera_on is True
self.camera.stop_camera()
assert self.camera.camera_on is False
def test_set_resolution(self):
color_resolution = [1280, 720]
depth_resolution = [1280, 720]
self.camera.set_resolution(color_resolution, depth_resolution)
assert self.camera._color_resolution == color_resolution
assert self.camera._depth_resolution == depth_resolution
def test_set_frame_rate(self):
color_fps = 60
depth_fps = 60
self.camera.set_frame_rate(color_fps, depth_fps)
assert self.camera._color_frames_rate == color_fps
assert self.camera._depth_frames_rate == depth_fps
def test_read_frame(self):
self.camera.start_camera()
color_image, depth_image, colorized_depth, point_cloud = (
self.camera.read_frame()
)
assert color_image is not None
assert depth_image is not None
self.camera.stop_camera()
def test_read_align_frame(self):
self.camera.start_camera()
color_image, depth_image, colorized_depth, point_cloud = (
self.camera.read_align_frame()
)
assert color_image is not None
assert depth_image is not None
self.camera.stop_camera()
def test_get_camera_intrinsics(self):
self.camera.start_camera()
color_intrinsics, depth_intrinsics = self.camera.get_camera_intrinsics()
assert color_intrinsics is not None
assert depth_intrinsics is not None
self.camera.stop_camera()
def test_get_3d_camera_coordinate(self):
self.camera.start_camera()
# call read_align_frame first to make sure _aligned_depth_frame is set
self.camera.read_align_frame()
depth_pixel = [320, 240]
distance, camera_coordinate = self.camera.get_3d_camera_coordinate(
depth_pixel, align=True
)
assert distance > 0
assert len(camera_coordinate) == 3
self.camera.stop_camera()

View File

@@ -0,0 +1,10 @@
__pycache__/
build/
devel/
dist/
data/
.catkin_workspace
*.pyc
*.pyo
*.pt
.vscode/

View File

@@ -0,0 +1,89 @@
# ACT: Action Chunking with Transformers
### *New*: [ACT tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing)
TL;DR: if your ACT policy is jerky or pauses in the middle of an episode, just train for longer! Success rate and smoothness can improve way after loss plateaus.
#### Project Website: https://tonyzhaozh.github.io/aloha/
This repo contains the implementation of ACT, together with 2 simulated environments:
Transfer Cube and Bimanual Insertion. You can train and evaluate ACT in sim or real.
For real, you would also need to install [ALOHA](https://github.com/tonyzhaozh/aloha).
### Updates:
You can find all scripted/human demo for simulated environments [here](https://drive.google.com/drive/folders/1gPR03v05S1xiInoVJn7G7VJ9pDCnxq9O?usp=share_link).
### Repo Structure
- ``imitate_episodes.py`` Train and Evaluate ACT
- ``policy.py`` An adaptor for ACT policy
- ``detr`` Model definitions of ACT, modified from DETR
- ``sim_env.py`` Mujoco + DM_Control environments with joint space control
- ``ee_sim_env.py`` Mujoco + DM_Control environments with EE space control
- ``scripted_policy.py`` Scripted policies for sim environments
- ``constants.py`` Constants shared across files
- ``utils.py`` Utils such as data loading and helper functions
- ``visualize_episodes.py`` Save videos from a .hdf5 dataset
### Installation
conda create -n aloha python=3.8.10
conda activate aloha
pip install torchvision
pip install torch
pip install pyquaternion
pip install pyyaml
pip install rospkg
pip install pexpect
pip install mujoco==2.3.7
pip install dm_control==1.0.14
pip install opencv-python
pip install matplotlib
pip install einops
pip install packaging
pip install h5py
pip install ipython
cd act/detr && pip install -e .
### Example Usages
To set up a new terminal, run:
conda activate aloha
cd <path to act repo>
### Simulated experiments
We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
To generate 50 episodes of scripted data, run:
python3 record_sim_episodes.py \
--task_name sim_transfer_cube_scripted \
--dataset_dir <data save dir> \
--num_episodes 50
You can add the flag ``--onscreen_render`` to see real-time rendering.
To visualize the episode after it is collected, run
python3 visualize_episodes.py --dataset_dir <data save dir> --episode_idx 0
To train ACT:
# Transfer Cube task
python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
--ckpt_dir <ckpt dir> \
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
--num_epochs 2000 --lr 1e-5 \
--seed 0
To evaluate the policy, run the same command but add ``--eval``. This loads the best validation checkpoint.
The success rate should be around 90% for transfer cube, and around 50% for insertion.
To enable temporal ensembling, add flag ``--temporal_agg``.
Videos will be saved to ``<ckpt_dir>`` for each rollout.
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.
For real-world data where things can be harder to model, train for at least 5000 epochs or 3-4 times the length after the loss has plateaued.
Please refer to [tuning tips](https://docs.google.com/document/d/1FVIZfoALXg_ZkYKaYVh-qOlaXveq5CtvJHXkY25eYhs/edit?usp=sharing) for more info.

View File

@@ -0,0 +1,74 @@
robot_env: {
# TODO change the path to the correct one
rm_left_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_left_arm.yaml',
rm_right_arm: '/home/rm/aloha/shadow_rm_aloha/config/rm_right_arm.yaml',
arm_axis: 6,
head_camera: '215222076892',
bottom_camera: '215222076981',
left_camera: '152122078151',
right_camera: '152122073489',
# init_left_arm_angle: [0.226, 21.180, 91.304, -0.515, 67.486, 2.374, 0.9],
# init_right_arm_angle: [-1.056, 33.057, 84.376, -0.204, 66.357, -3.236, 0.9]
init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}
dataset_dir: '/home/rm/aloha/shadow_rm_aloha/data/dataset/20250103'
checkpoint_dir: '/home/rm/aloha/shadow_rm_act/data'
# checkpoint_name: 'policy_best.ckpt'
checkpoint_name: 'policy_9500.ckpt'
state_dim: 14
save_episode: True
num_rollouts: 50 # number of rollout trajectories to collect during training
real_robot: True
policy_class: 'ACT'
onscreen_render: False
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right']
episode_len: 300 # maximum episode length (number of time steps)
task_name: 'aloha_01_11.28'
temporal_agg: False # whether to use temporal aggregation
batch_size: 8 # samples per batch during training
seed: 1000 # random seed
chunk_size: 30 # chunk size used for processing sequences
eval_every: 1 # evaluate the model every eval_every steps
num_steps: 10000 # total number of training steps
validate_every: 1 # validate the model every validate_every steps
save_every: 500 # save a checkpoint every save_every steps
load_pretrain: False # whether to load a pretrained model
resume_ckpt_path:
name_filter: # TODO
skip_mirrored_data: False # whether to skip mirrored data (e.g. for symmetry-based data augmentation)
stats_dir:
sample_weights:
train_ratio: 0.8 # fraction of the data used for training (the rest is used for validation)
policy_config: {
hidden_dim: 512, # Size of the embeddings (dimension of the transformer)
state_dim: 14, # Dimension of the state
position_embedding: 'sine', # ('sine', 'learned').Type of positional embedding to use on top of the image features
lr_backbone: 1.0e-5,
masks: False, # If true, the model masks the non-visible pixels
backbone: 'resnet18',
dilation: False, # If true, we replace stride with dilation in the last convolutional block (DC5)
dropout: 0.1, # Dropout applied in the transformer
nheads: 8,
dim_feedforward: 3200, # Intermediate size of the feedforward layers in the transformer blocks
enc_layers: 4, # Number of encoding layers in the transformer
dec_layers: 7, # Number of decoding layers in the transformer
pre_norm: False, # If true, apply LayerNorm to the input instead of the output of the MultiheadAttention and FeedForward
num_queries: 30,
camera_names: ['cam_high', 'cam_low', 'cam_left', 'cam_right'],
vq: False,
vq_class: none,
vq_dim: 64,
action_dim: 14,
no_encoder: False,
lr: 1.0e-5,
weight_decay: 1.0e-4,
kl_weight: 10,
# lr_drop: 200,
# clip_max_norm: 0.1,
}

View File

@@ -0,0 +1,267 @@
import numpy as np
import collections
import os
from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_CLOSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
from src.shadow_act.utils.utils import sample_box_pose, sample_insertion_pose
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
import IPython
e = IPython.embed
def make_ee_sim_env(task_name):
"""
Environment for simulated robot bi-manual manipulation, with end-effector control.
Action space: [left_arm_pose (7), # position and quaternion for end effector
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_pose (7), # position and quaternion for end effector
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
if 'sim_transfer_cube' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = TransferCubeEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
elif 'sim_insertion' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = InsertionEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
else:
raise NotImplementedError
return env
class BimanualViperXEETask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
a_len = len(action) // 2
action_left = action[:a_len]
action_right = action[a_len:]
# set mocap position and quat
# left
np.copyto(physics.data.mocap_pos[0], action_left[:3])
np.copyto(physics.data.mocap_quat[0], action_left[3:7])
# right
np.copyto(physics.data.mocap_pos[1], action_right[:3])
np.copyto(physics.data.mocap_quat[1], action_right[3:7])
# set gripper
g_left_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_left[7])
g_right_ctrl = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(action_right[7])
np.copyto(physics.data.ctrl, np.array([g_left_ctrl, -g_left_ctrl, g_right_ctrl, -g_right_ctrl]))
def initialize_robots(self, physics):
# reset joint position
physics.named.data.qpos[:16] = START_ARM_POSE
# reset mocap to align with end effector
# to obtain these numbers:
# (1) make an ee_sim env and reset to the same start_pose
# (2) get env._physics.named.data.xpos['vx300s_left/gripper_link']
# get env._physics.named.data.xquat['vx300s_left/gripper_link']
# repeat the same for right side
np.copyto(physics.data.mocap_pos[0], [-0.31718881, 0.5, 0.29525084])
np.copyto(physics.data.mocap_quat[0], [1, 0, 0, 0])
# right
np.copyto(physics.data.mocap_pos[1], np.array([0.31718881, 0.49999888, 0.29525084]))
np.copyto(physics.data.mocap_quat[1], [1, 0, 0, 0])
# reset gripper control
close_gripper_control = np.array([
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
PUPPET_GRIPPER_POSITION_CLOSE,
-PUPPET_GRIPPER_POSITION_CLOSE,
])
np.copyto(physics.data.ctrl, close_gripper_control)
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
# note: it is important to do .copy()
obs = collections.OrderedDict()
obs['qpos'] = self.get_qpos(physics)
obs['qvel'] = self.get_qvel(physics)
obs['env_state'] = self.get_env_state(physics)
obs['images'] = dict()
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
# used in scripted policy to obtain starting pose
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
# used when replaying joint trajectory
obs['gripper_ctrl'] = physics.data.ctrl.copy()
return obs
def get_reward(self, physics):
raise NotImplementedError
class TransferCubeEETask(BimanualViperXEETask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize box position
cube_pose = sample_box_pose()
box_start_idx = physics.model.name2id('red_box_joint', 'joint')
np.copyto(physics.data.qpos[box_start_idx : box_start_idx + 7], cube_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionEETask(BimanualViperXEETask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
self.initialize_robots(physics)
# randomize peg and socket position
peg_pose, socket_pose = sample_insertion_pose()
id2index = lambda j_id: 16 + (j_id - 16) * 7 # first 16 is robot qpos, 7 is pose dim # hacky
peg_start_id = physics.model.name2id('red_peg_joint', 'joint')
peg_start_idx = id2index(peg_start_id)
np.copyto(physics.data.qpos[peg_start_idx : peg_start_idx + 7], peg_pose)
# print(f"randomized cube position to {cube_position}")
socket_start_id = physics.model.name2id('blue_socket_joint', 'joint')
socket_start_idx = id2index(socket_start_id)
np.copyto(physics.data.qpos[socket_start_idx : socket_start_idx + 7], socket_pose)
# print(f"randomized cube position to {cube_position}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
("socket-2", "table") in all_contact_pairs or \
("socket-3", "table") in all_contact_pairs or \
("socket-4", "table") in all_contact_pairs
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
("red_peg", "socket-2") in all_contact_pairs or \
("red_peg", "socket-3") in all_contact_pairs or \
("red_peg", "socket-4") in all_contact_pairs
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward
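def _example_ee_rollout(num_steps=10):
    # Illustrative usage sketch (an addition, not part of the original module), assuming the
    # MuJoCo XML assets under XML_DIR are available: hold both end effectors at their initial
    # mocap poses with open grippers for a few steps.
    env = make_ee_sim_env('sim_transfer_cube_scripted')
    ts = env.reset()
    # action layout: [left xyz(3)+quat(4)+gripper(1), right xyz(3)+quat(4)+gripper(1)]
    action = np.concatenate([ts.observation['mocap_pose_left'], [1.0],
                             ts.observation['mocap_pose_right'], [1.0]])
    for _ in range(num_steps):
        ts = env.step(action)
    return ts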

View File

@@ -0,0 +1,36 @@
[tool.poetry]
name = "shadow_act"
version = "0.1.0"
description = "Embodied data, ACT and other methods; training and verification function packages"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
"Operating System :: POSIX :: Linux amd64",
"Programming Language :: Python :: 3.10",
]
[tool.poetry.dependencies]
python = ">=3.9"
wandb = ">=0.18.0"
einops = ">=0.8.0"
[tool.poetry.dev-dependencies] # dependencies needed only for development, e.g. testing and documentation tools
pytest = ">=8.3"
black = ">=24.10.0"
[tool.poetry.plugins."scripts"] # defines command-line scripts so users can run the specified functions from the command line
[tool.poetry.group.dev.dependencies]
[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,189 @@
import time
import os
import numpy as np
import argparse
import matplotlib.pyplot as plt
import h5py
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
import IPython
e = IPython.embed
def main(args):
"""
Generate demonstration data in simulation.
First rollout the policy (defined in ee space) in ee_sim_env. Obtain the joint trajectory.
Replace the gripper joint positions with the commanded joint position.
Replay this joint trajectory (as action sequence) in sim_env, and record all observations.
Save this episode of data, and continue to next episode of data collection.
"""
task_name = args['task_name']
dataset_dir = args['dataset_dir']
num_episodes = args['num_episodes']
onscreen_render = args['onscreen_render']
inject_noise = False
render_cam_name = 'angle'
if not os.path.isdir(dataset_dir):
os.makedirs(dataset_dir, exist_ok=True)
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
if task_name == 'sim_transfer_cube_scripted':
policy_cls = PickAndTransferPolicy
elif task_name == 'sim_insertion_scripted':
policy_cls = InsertionPolicy
else:
raise NotImplementedError
success = []
for episode_idx in range(num_episodes):
print(f'{episode_idx=}')
        print('Rolling out the EE-space scripted policy')
# setup the environment
env = make_ee_sim_env(task_name)
ts = env.reset()
episode = [ts]
policy = policy_cls(inject_noise)
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for step in range(episode_len):
action = policy(ts)
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.002)
plt.close()
episode_return = np.sum([ts.reward for ts in episode[1:]])
episode_max_reward = np.max([ts.reward for ts in episode[1:]])
if episode_max_reward == env.task.max_reward:
print(f"{episode_idx=} Successful, {episode_return=}")
else:
print(f"{episode_idx=} Failed")
joint_traj = [ts.observation['qpos'] for ts in episode]
# replace gripper pose with gripper control
gripper_ctrl_traj = [ts.observation['gripper_ctrl'] for ts in episode]
for joint, ctrl in zip(joint_traj, gripper_ctrl_traj):
left_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[0])
right_ctrl = PUPPET_GRIPPER_POSITION_NORMALIZE_FN(ctrl[2])
joint[6] = left_ctrl
joint[6+7] = right_ctrl
subtask_info = episode[0].observation['env_state'].copy() # box pose at step 0
# clear unused variables
del env
del episode
del policy
# setup the environment
print('Replaying joint commands')
env = make_sim_env(task_name)
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
ts = env.reset()
episode_replay = [ts]
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for t in range(len(joint_traj)): # note: this will increase episode length by 1
action = joint_traj[t]
ts = env.step(action)
episode_replay.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.02)
episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
episode_max_reward = np.max([ts.reward for ts in episode_replay[1:]])
if episode_max_reward == env.task.max_reward:
success.append(1)
print(f"{episode_idx=} Successful, {episode_return=}")
else:
success.append(0)
print(f"{episode_idx=} Failed")
plt.close()
"""
For each timestep:
observations
- images
- each_cam_name (480, 640, 3) 'uint8'
- qpos (14,) 'float64'
- qvel (14,) 'float64'
action (14,) 'float64'
"""
data_dict = {
'/observations/qpos': [],
'/observations/qvel': [],
'/action': [],
}
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'] = []
        # because of the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
# truncate here to be consistent
joint_traj = joint_traj[:-1]
episode_replay = episode_replay[:-1]
# len(joint_traj) i.e. actions: max_timesteps
# len(episode_replay) i.e. time steps: max_timesteps + 1
max_timesteps = len(joint_traj)
while joint_traj:
action = joint_traj.pop(0)
ts = episode_replay.pop(0)
data_dict['/observations/qpos'].append(ts.observation['qpos'])
data_dict['/observations/qvel'].append(ts.observation['qvel'])
data_dict['/action'].append(action)
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
# HDF5
t0 = time.time()
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}')
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
root.attrs['sim'] = True
obs = root.create_group('observations')
image = obs.create_group('images')
for cam_name in camera_names:
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
# compression='gzip',compression_opts=2,)
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
qvel = obs.create_dataset('qvel', (max_timesteps, 14))
action = root.create_dataset('action', (max_timesteps, 14))
for name, array in data_dict.items():
root[name][...] = array
print(f'Saving: {time.time() - t0:.1f} secs\n')
print(f'Saved to {dataset_dir}')
print(f'Success: {np.sum(success)} / {len(success)}')
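def _example_load_episode(dataset_path):
    # Illustrative sketch (an addition, not part of the original script): read back one saved
    # episode; the keys follow the data_dict layout written above. The camera name 'top' is an
    # assumption -- use the names listed in SIM_TASK_CONFIGS for your task.
    with h5py.File(dataset_path, 'r') as root:
        qpos = root['/observations/qpos'][()]              # (T, 14) float64
        action = root['/action'][()]                       # (T, 14) float64
        top_images = root['/observations/images/top'][()]  # (T, 480, 640, 3) uint8
    return qpos, action, top_images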
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset saving dir', required=True)
parser.add_argument('--num_episodes', action='store', type=int, help='num_episodes', required=False)
parser.add_argument('--onscreen_render', action='store_true')
main(vars(parser.parse_args()))

View File

@@ -0,0 +1,194 @@
import numpy as np
import matplotlib.pyplot as plt
from pyquaternion import Quaternion
from constants import SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
import IPython
e = IPython.embed
class BasePolicy:
def __init__(self, inject_noise=False):
self.inject_noise = inject_noise
self.step_count = 0
self.left_trajectory = None
self.right_trajectory = None
def generate_trajectory(self, ts_first):
raise NotImplementedError
@staticmethod
def interpolate(curr_waypoint, next_waypoint, t):
t_frac = (t - curr_waypoint["t"]) / (next_waypoint["t"] - curr_waypoint["t"])
curr_xyz = curr_waypoint['xyz']
curr_quat = curr_waypoint['quat']
curr_grip = curr_waypoint['gripper']
next_xyz = next_waypoint['xyz']
next_quat = next_waypoint['quat']
next_grip = next_waypoint['gripper']
xyz = curr_xyz + (next_xyz - curr_xyz) * t_frac
        # note: quaternion components are linearly interpolated (lerp), not slerped
        quat = curr_quat + (next_quat - curr_quat) * t_frac
gripper = curr_grip + (next_grip - curr_grip) * t_frac
return xyz, quat, gripper
def __call__(self, ts):
# generate trajectory at first timestep, then open-loop execution
if self.step_count == 0:
self.generate_trajectory(ts)
# obtain left and right waypoints
if self.left_trajectory[0]['t'] == self.step_count:
self.curr_left_waypoint = self.left_trajectory.pop(0)
next_left_waypoint = self.left_trajectory[0]
if self.right_trajectory[0]['t'] == self.step_count:
self.curr_right_waypoint = self.right_trajectory.pop(0)
next_right_waypoint = self.right_trajectory[0]
# interpolate between waypoints to obtain current pose and gripper command
left_xyz, left_quat, left_gripper = self.interpolate(self.curr_left_waypoint, next_left_waypoint, self.step_count)
right_xyz, right_quat, right_gripper = self.interpolate(self.curr_right_waypoint, next_right_waypoint, self.step_count)
# Inject noise
if self.inject_noise:
scale = 0.01
left_xyz = left_xyz + np.random.uniform(-scale, scale, left_xyz.shape)
right_xyz = right_xyz + np.random.uniform(-scale, scale, right_xyz.shape)
action_left = np.concatenate([left_xyz, left_quat, [left_gripper]])
action_right = np.concatenate([right_xyz, right_quat, [right_gripper]])
self.step_count += 1
return np.concatenate([action_left, action_right])
class PickAndTransferPolicy(BasePolicy):
def generate_trajectory(self, ts_first):
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
box_info = np.array(ts_first.observation['env_state'])
box_xyz = box_info[:3]
box_quat = box_info[3:]
# print(f"Generate trajectory for {box_xyz=}")
gripper_pick_quat = Quaternion(init_mocap_pose_right[3:])
gripper_pick_quat = gripper_pick_quat * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
meet_left_quat = Quaternion(axis=[1.0, 0.0, 0.0], degrees=90)
meet_xyz = np.array([0, 0.5, 0.25])
self.left_trajectory = [
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
{"t": 100, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # approach meet position
{"t": 260, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 1}, # move to meet position
{"t": 310, "xyz": meet_xyz + np.array([0.02, 0, -0.02]), "quat": meet_left_quat.elements, "gripper": 0}, # close gripper
{"t": 360, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # move left
{"t": 400, "xyz": meet_xyz + np.array([-0.1, 0, -0.02]), "quat": np.array([1, 0, 0, 0]), "gripper": 0}, # stay
]
self.right_trajectory = [
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
{"t": 90, "xyz": box_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat.elements, "gripper": 1}, # approach the cube
{"t": 130, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 1}, # go down
{"t": 170, "xyz": box_xyz + np.array([0, 0, -0.015]), "quat": gripper_pick_quat.elements, "gripper": 0}, # close gripper
{"t": 200, "xyz": meet_xyz + np.array([0.05, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 0}, # approach meet position
{"t": 220, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 0}, # move to meet position
{"t": 310, "xyz": meet_xyz, "quat": gripper_pick_quat.elements, "gripper": 1}, # open gripper
{"t": 360, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # move to right
{"t": 400, "xyz": meet_xyz + np.array([0.1, 0, 0]), "quat": gripper_pick_quat.elements, "gripper": 1}, # stay
]
class InsertionPolicy(BasePolicy):
def generate_trajectory(self, ts_first):
init_mocap_pose_right = ts_first.observation['mocap_pose_right']
init_mocap_pose_left = ts_first.observation['mocap_pose_left']
peg_info = np.array(ts_first.observation['env_state'])[:7]
peg_xyz = peg_info[:3]
peg_quat = peg_info[3:]
socket_info = np.array(ts_first.observation['env_state'])[7:]
socket_xyz = socket_info[:3]
socket_quat = socket_info[3:]
gripper_pick_quat_right = Quaternion(init_mocap_pose_right[3:])
gripper_pick_quat_right = gripper_pick_quat_right * Quaternion(axis=[0.0, 1.0, 0.0], degrees=-60)
gripper_pick_quat_left = Quaternion(init_mocap_pose_right[3:])
gripper_pick_quat_left = gripper_pick_quat_left * Quaternion(axis=[0.0, 1.0, 0.0], degrees=60)
meet_xyz = np.array([0, 0.5, 0.15])
lift_right = 0.00715
self.left_trajectory = [
{"t": 0, "xyz": init_mocap_pose_left[:3], "quat": init_mocap_pose_left[3:], "gripper": 0}, # sleep
{"t": 120, "xyz": socket_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # approach the cube
{"t": 170, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 1}, # go down
{"t": 220, "xyz": socket_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # close gripper
{"t": 285, "xyz": meet_xyz + np.array([-0.1, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # approach meet position
{"t": 340, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements,"gripper": 0}, # insertion
{"t": 400, "xyz": meet_xyz + np.array([-0.05, 0, 0]), "quat": gripper_pick_quat_left.elements, "gripper": 0}, # insertion
]
self.right_trajectory = [
{"t": 0, "xyz": init_mocap_pose_right[:3], "quat": init_mocap_pose_right[3:], "gripper": 0}, # sleep
{"t": 120, "xyz": peg_xyz + np.array([0, 0, 0.08]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # approach the cube
{"t": 170, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 1}, # go down
{"t": 220, "xyz": peg_xyz + np.array([0, 0, -0.03]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # close gripper
{"t": 285, "xyz": meet_xyz + np.array([0.1, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # approach meet position
{"t": 340, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
{"t": 400, "xyz": meet_xyz + np.array([0.05, 0, lift_right]), "quat": gripper_pick_quat_right.elements, "gripper": 0}, # insertion
]
def test_policy(task_name):
# example rolling out pick_and_transfer policy
onscreen_render = True
inject_noise = False
# setup the environment
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
if 'sim_transfer_cube' in task_name:
env = make_ee_sim_env('sim_transfer_cube')
elif 'sim_insertion' in task_name:
env = make_ee_sim_env('sim_insertion')
else:
raise NotImplementedError
for episode_idx in range(2):
ts = env.reset()
episode = [ts]
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['angle'])
plt.ion()
policy = PickAndTransferPolicy(inject_noise)
for step in range(episode_len):
action = policy(ts)
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['angle'])
plt.pause(0.02)
plt.close()
episode_return = np.sum([ts.reward for ts in episode[1:]])
if episode_return > 0:
print(f"{episode_idx=} Successful, {episode_return=}")
else:
print(f"{episode_idx=} Failed")
if __name__ == '__main__':
test_task_name = 'sim_transfer_cube_scripted'
test_policy(test_task_name)

View File

@@ -0,0 +1,278 @@
import numpy as np
import os
import collections
import matplotlib.pyplot as plt
from dm_control import mujoco
from dm_control.rl import control
from dm_control.suite import base
from constants import DT, XML_DIR, START_ARM_POSE
from constants import PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN
from constants import MASTER_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN
import IPython
e = IPython.embed
BOX_POSE = [None] # to be changed from outside
def make_sim_env(task_name):
"""
Environment for simulated robot bi-manual manipulation, with joint position control
Action space: [left_arm_qpos (6), # absolute joint position
left_gripper_positions (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_positions (1),] # normalized gripper position (0: close, 1: open)
Observation space: {"qpos": Concat[ left_arm_qpos (6), # absolute joint position
left_gripper_position (1), # normalized gripper position (0: close, 1: open)
right_arm_qpos (6), # absolute joint position
right_gripper_qpos (1)] # normalized gripper position (0: close, 1: open)
"qvel": Concat[ left_arm_qvel (6), # absolute joint velocity (rad)
left_gripper_velocity (1), # normalized gripper velocity (pos: opening, neg: closing)
right_arm_qvel (6), # absolute joint velocity (rad)
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
if 'sim_transfer_cube' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_transfer_cube.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = TransferCubeTask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
elif 'sim_insertion' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_insertion.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = InsertionTask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
else:
raise NotImplementedError
return env
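def _example_joint_rollout(num_steps=5):
    # Illustrative usage sketch (an addition, not part of the original module): BOX_POSE must be
    # set before env.reset(); the zero action holds both arms at qpos 0 with closed grippers
    # (normalized gripper position 0).
    BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]
    env = make_sim_env('sim_transfer_cube_scripted')
    ts = env.reset()
    for _ in range(num_steps):
        ts = env.step(np.zeros(14))  # [left qpos(6)+gripper(1), right qpos(6)+gripper(1)]
    return ts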
class BimanualViperXTask(base.Task):
def __init__(self, random=None):
super().__init__(random=random)
def before_step(self, action, physics):
left_arm_action = action[:6]
right_arm_action = action[7:7+6]
normalized_left_gripper_action = action[6]
normalized_right_gripper_action = action[7+6]
left_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_left_gripper_action)
right_gripper_action = PUPPET_GRIPPER_POSITION_UNNORMALIZE_FN(normalized_right_gripper_action)
full_left_gripper_action = [left_gripper_action, -left_gripper_action]
full_right_gripper_action = [right_gripper_action, -right_gripper_action]
env_action = np.concatenate([left_arm_action, full_left_gripper_action, right_arm_action, full_right_gripper_action])
super().before_step(env_action, physics)
return
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
super().initialize_episode(physics)
@staticmethod
def get_qpos(physics):
qpos_raw = physics.data.qpos.copy()
left_qpos_raw = qpos_raw[:8]
right_qpos_raw = qpos_raw[8:16]
left_arm_qpos = left_qpos_raw[:6]
right_arm_qpos = right_qpos_raw[:6]
left_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(left_qpos_raw[6])]
right_gripper_qpos = [PUPPET_GRIPPER_POSITION_NORMALIZE_FN(right_qpos_raw[6])]
return np.concatenate([left_arm_qpos, left_gripper_qpos, right_arm_qpos, right_gripper_qpos])
@staticmethod
def get_qvel(physics):
qvel_raw = physics.data.qvel.copy()
left_qvel_raw = qvel_raw[:8]
right_qvel_raw = qvel_raw[8:16]
left_arm_qvel = left_qvel_raw[:6]
right_arm_qvel = right_qvel_raw[:6]
left_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(left_qvel_raw[6])]
right_gripper_qvel = [PUPPET_GRIPPER_VELOCITY_NORMALIZE_FN(right_qvel_raw[6])]
return np.concatenate([left_arm_qvel, left_gripper_qvel, right_arm_qvel, right_gripper_qvel])
@staticmethod
def get_env_state(physics):
raise NotImplementedError
def get_observation(self, physics):
obs = collections.OrderedDict()
obs['qpos'] = self.get_qpos(physics)
obs['qvel'] = self.get_qvel(physics)
obs['env_state'] = self.get_env_state(physics)
obs['images'] = dict()
obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
return obs
def get_reward(self, physics):
# return whether left gripper is holding the box
raise NotImplementedError
class TransferCubeTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7:] = BOX_POSE[0]
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether left gripper is holding the box
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_left_gripper = ("red_box", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
touch_right_gripper = ("red_box", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_table = ("red_box", "table") in all_contact_pairs
reward = 0
if touch_right_gripper:
reward = 1
if touch_right_gripper and not touch_table: # lifted
reward = 2
if touch_left_gripper: # attempted transfer
reward = 3
if touch_left_gripper and not touch_table: # successful transfer
reward = 4
return reward
class InsertionTask(BimanualViperXTask):
def __init__(self, random=None):
super().__init__(random=random)
self.max_reward = 4
def initialize_episode(self, physics):
"""Sets the state of the environment at the start of each episode."""
# TODO Notice: this function does not randomize the env configuration. Instead, set BOX_POSE from outside
# reset qpos, control and box position
with physics.reset_context():
physics.named.data.qpos[:16] = START_ARM_POSE
np.copyto(physics.data.ctrl, START_ARM_POSE)
assert BOX_POSE[0] is not None
physics.named.data.qpos[-7*2:] = BOX_POSE[0] # two objects
# print(f"{BOX_POSE=}")
super().initialize_episode(physics)
@staticmethod
def get_env_state(physics):
env_state = physics.data.qpos.copy()[16:]
return env_state
def get_reward(self, physics):
# return whether peg touches the pin
all_contact_pairs = []
for i_contact in range(physics.data.ncon):
id_geom_1 = physics.data.contact[i_contact].geom1
id_geom_2 = physics.data.contact[i_contact].geom2
name_geom_1 = physics.model.id2name(id_geom_1, 'geom')
name_geom_2 = physics.model.id2name(id_geom_2, 'geom')
contact_pair = (name_geom_1, name_geom_2)
all_contact_pairs.append(contact_pair)
touch_right_gripper = ("red_peg", "vx300s_right/10_right_gripper_finger") in all_contact_pairs
touch_left_gripper = ("socket-1", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-2", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-3", "vx300s_left/10_left_gripper_finger") in all_contact_pairs or \
("socket-4", "vx300s_left/10_left_gripper_finger") in all_contact_pairs
peg_touch_table = ("red_peg", "table") in all_contact_pairs
socket_touch_table = ("socket-1", "table") in all_contact_pairs or \
("socket-2", "table") in all_contact_pairs or \
("socket-3", "table") in all_contact_pairs or \
("socket-4", "table") in all_contact_pairs
peg_touch_socket = ("red_peg", "socket-1") in all_contact_pairs or \
("red_peg", "socket-2") in all_contact_pairs or \
("red_peg", "socket-3") in all_contact_pairs or \
("red_peg", "socket-4") in all_contact_pairs
pin_touched = ("red_peg", "pin") in all_contact_pairs
reward = 0
if touch_left_gripper and touch_right_gripper: # touch both
reward = 1
if touch_left_gripper and touch_right_gripper and (not peg_touch_table) and (not socket_touch_table): # grasp both
reward = 2
if peg_touch_socket and (not peg_touch_table) and (not socket_touch_table): # peg and socket touching
reward = 3
if pin_touched: # successful insertion
reward = 4
return reward
def get_action(master_bot_left, master_bot_right):
action = np.zeros(14)
# arm action
action[:6] = master_bot_left.dxl.joint_states.position[:6]
action[7:7+6] = master_bot_right.dxl.joint_states.position[:6]
# gripper action
left_gripper_pos = master_bot_left.dxl.joint_states.position[7]
right_gripper_pos = master_bot_right.dxl.joint_states.position[7]
normalized_left_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(left_gripper_pos)
normalized_right_pos = MASTER_GRIPPER_POSITION_NORMALIZE_FN(right_gripper_pos)
action[6] = normalized_left_pos
action[7+6] = normalized_right_pos
return action
def test_sim_teleop():
""" Testing teleoperation in sim with ALOHA. Requires hardware and ALOHA repo to work. """
from interbotix_xs_modules.arm import InterbotixManipulatorXS
BOX_POSE[0] = [0.2, 0.5, 0.05, 1, 0, 0, 0]
# source of data
master_bot_left = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
robot_name=f'master_left', init_node=True)
master_bot_right = InterbotixManipulatorXS(robot_model="wx250s", group_name="arm", gripper_name="gripper",
robot_name=f'master_right', init_node=False)
# setup the environment
env = make_sim_env('sim_transfer_cube')
ts = env.reset()
episode = [ts]
# setup plotting
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['angle'])
plt.ion()
for t in range(1000):
action = get_action(master_bot_left, master_bot_right)
ts = env.step(action)
episode.append(ts)
plt_img.set_data(ts.observation['images']['angle'])
plt.pause(0.02)
if __name__ == '__main__':
test_sim_teleop()

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,575 @@
import os
import time
import yaml
import torch
import pickle
import dm_env
import logging
import collections
import numpy as np
import tracemalloc
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision import transforms
from shadow_rm_robot.realman_arm import RmArm
from shadow_camera.realsense import RealSenseCamera
from shadow_act.models.latent_model import Latent_Model_Transformer
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.utils.utils import set_seed
# configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
# # hide the h5py warning "Creating converter from 7 to 5"
# logging.getLogger("h5py").setLevel(logging.WARNING)
class RmActEvaluator:
def __init__(self, config, save_episode=True, num_rollouts=50):
"""
初始化Evaluator类
Args:
config (dict): 配置字典
checkpoint_name (str): 检查点名称
save_episode (bool): 是否保存每个episode
num_rollouts (int): 滚动次数
"""
self.config = config
self._seed = config["seed"]
self.robot_env = config["robot_env"]
self.checkpoint_dir = config["checkpoint_dir"]
self.checkpoint_name = config["checkpoint_name"]
self.save_episode = save_episode
self.num_rollouts = num_rollouts
self.state_dim = config["state_dim"]
self.real_robot = config["real_robot"]
self.policy_class = config["policy_class"]
self.onscreen_render = config["onscreen_render"]
self.camera_names = config["camera_names"]
self.max_timesteps = config["episode_len"]
self.task_name = config["task_name"]
self.temporal_agg = config["temporal_agg"]
self.onscreen_cam = "angle"
self.policy_config = config["policy_config"]
self.vq = config["policy_config"]["vq"]
# self.actuator_config = config["actuator_config"]
# self.use_actuator_net = self.actuator_config["actuator_network_dir"] is not None
self.stats = None
self.env = None
self.env_max_reward = 0
def _make_policy(self, policy_class, policy_config):
"""
根据策略类和配置创建策略对象
Args:
policy_class (str): 策略类名称
policy_config (dict): 策略配置字典
Returns:
policy: 创建的策略对象
"""
if policy_class == "ACT":
return ACTPolicy(policy_config)
elif policy_class == "CNNMLP":
return CNNMLPPolicy(policy_config)
elif policy_class == "Diffusion":
return DiffusionPolicy(policy_config)
else:
raise NotImplementedError(f"Policy class {policy_class} is not implemented")
def load_policy_and_stats(self):
"""
加载策略和统计数据
"""
checkpoint_path = os.path.join(self.checkpoint_dir, self.checkpoint_name)
logging.info(f"Loading policy from: {checkpoint_path}")
self.policy = self._make_policy(self.policy_class, self.policy_config)
        # load the model and set it to evaluation mode
self.policy.load_state_dict(torch.load(checkpoint_path, weights_only=True))
self.policy.cuda()
self.policy.eval()
if self.vq:
vq_dim = self.config["policy_config"]["vq_dim"]
vq_class = self.config["policy_config"]["vq_class"]
self.latent_model = Latent_Model_Transformer(vq_dim, vq_dim, vq_class)
latent_model_checkpoint_path = os.path.join(
self.checkpoint_dir, "latent_model_last.ckpt"
)
self.latent_model.deserialize(torch.load(latent_model_checkpoint_path))
self.latent_model.eval()
self.latent_model.cuda()
logging.info(
f"Loaded policy from: {checkpoint_path}, latent model from: {latent_model_checkpoint_path}"
)
else:
logging.info(f"Loaded: {checkpoint_path}")
stats_path = os.path.join(self.checkpoint_dir, "dataset_stats.pkl")
with open(stats_path, "rb") as f:
self.stats = pickle.load(f)
def pre_process(self, state_qpos):
"""
预处理状态位置
Args:
state_qpos (np.array): 状态位置数组
Returns:
np.array: 预处理后的状态位置
"""
if self.policy_class == "Diffusion":
return ((state_qpos + 1) / 2) * (
self.stats["action_max"] - self.stats["action_min"]
) + self.stats["action_min"]
        # standardize: zero mean, unit standard deviation
return (state_qpos - self.stats["qpos_mean"]) / self.stats["qpos_std"]
def post_process(self, action):
"""
后处理动作
Args:
action (np.array): 动作数组
Returns:
np.array: 后处理后的动作
"""
        # de-standardize (invert the normalization)
return action * self.stats["action_std"] + self.stats["action_mean"]
def get_image_torch(self, timestep, camera_names, random_crop_resize=False):
"""
获取图像
Args:
timestep (object): 时间步对象
camera_names (list): 相机名称列表
random_crop_resize (bool): 是否随机裁剪和调整大小
Returns:
torch.Tensor: 处理后的图像,归一化(num_cameras, channels, height, width)
"""
current_images = []
for cam_name in camera_names:
current_image = rearrange(
timestep.observation["images"][cam_name], "h w c -> c h w"
)
current_images.append(current_image)
current_image = np.stack(current_images, axis=0)
current_image = (
torch.from_numpy(current_image / 255.0).float().cuda().unsqueeze(0)
)
if random_crop_resize:
logging.info("Random crop resize is used!")
original_size = current_image.shape[-2:]
ratio = 0.95
current_image = current_image[
...,
int(original_size[0] * (1 - ratio) / 2) : int(
original_size[0] * (1 + ratio) / 2
),
int(original_size[1] * (1 - ratio) / 2) : int(
original_size[1] * (1 + ratio) / 2
),
]
current_image = current_image.squeeze(0)
resize_transform = transforms.Resize(original_size, antialias=True)
current_image = resize_transform(current_image)
current_image = current_image.unsqueeze(0)
return current_image
def load_environment(self):
"""
加载环境
"""
if self.real_robot:
self.env = DeviceAloha(self.robot_env)
self.env_max_reward = 0
else:
from sim_env import make_sim_env
self.env = make_sim_env(self.task_name)
self.env_max_reward = self.env.task.max_reward
def get_auto_index(self, checkpoint_dir):
max_idx = 1000
for i in range(max_idx + 1):
if not os.path.isfile(os.path.join(checkpoint_dir, f"qpos_{i}.npy")):
return i
raise Exception(f"Error getting auto index, or more than {max_idx} episodes")
def evaluate(self, checkpoint_name=None):
"""
评估策略
Returns:
tuple: 成功率和平均回报
"""
if checkpoint_name is not None:
self.checkpoint_name = checkpoint_name
        set_seed(self._seed)  # random seed for numpy and torch
self.load_policy_and_stats()
self.load_environment()
query_frequency = self.policy_config["num_queries"]
        # with temporal aggregation, query the policy only once per timestep
if self.temporal_agg:
query_frequency = 1
num_queries = self.policy_config["num_queries"]
        # # base delay of 13 when using the real robot
# if self.real_robot:
# BASE_DELAY = 13
# # query_frequency -= BASE_DELAY
max_timesteps = int(self.max_timesteps * 1) # may increase for real-world tasks
episode_returns = []
highest_rewards = []
for rollout_id in range(self.num_rollouts):
timestep = self.env.reset()
if self.onscreen_render:
                # TODO: plotting
pass
if self.temporal_agg:
all_time_actions = torch.zeros(
[max_timesteps, max_timesteps + num_queries, self.state_dim]
).cuda()
qpos_history_raw = np.zeros((max_timesteps, self.state_dim))
rewards = []
with torch.inference_mode():
time_0 = time.time()
DT = 1 / 30
culmulated_delay = 0
for t in range(max_timesteps):
time_1 = time.time()
if self.onscreen_render:
                        # TODO: display the image
pass
# process previous timestep to get qpos and image_list
obs = timestep.observation
qpos_numpy = np.array(obs["qpos"])
qpos_history_raw[t] = qpos_numpy
qpos = self.pre_process(qpos_numpy)
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
logging.info(f"t{t}")
if t % query_frequency == 0:
current_image = self.get_image_torch(
timestep,
self.camera_names,
random_crop_resize=(
self.config["policy_class"] == "Diffusion"
),
)
if t == 0:
                            # network warm-up
for _ in range(10):
self.policy(qpos, current_image)
logging.info("Network warm up done")
if self.config["policy_class"] == "ACT":
if t % query_frequency == 0:
if self.vq:
if rollout_id == 0:
for _ in range(10):
vq_sample = self.latent_model.generate(
1, temperature=1, x=None
)
logging.info(
torch.nonzero(vq_sample[0])[:, 1]
.cpu()
.numpy()
)
vq_sample = self.latent_model.generate(
1, temperature=1, x=None
)
all_actions = self.policy(
qpos, current_image, vq_sample=vq_sample
)
else:
all_actions = self.policy(qpos, current_image)
# if self.real_robot:
# all_actions = torch.cat(
# [
# all_actions[:, :-BASE_DELAY, :-2],
# all_actions[:, BASE_DELAY:, -2:],
# ],
# dim=2,
# )
if self.temporal_agg:
all_time_actions[[t], t : t + num_queries] = all_actions
actions_for_curr_step = all_time_actions[:, t]
actions_populated = torch.all(
actions_for_curr_step != 0, axis=1
)
actions_for_curr_step = actions_for_curr_step[
actions_populated
]
k = 0.01
exp_weights = np.exp(
-k * np.arange(len(actions_for_curr_step))
)
exp_weights = exp_weights / exp_weights.sum()
exp_weights = (
torch.from_numpy(exp_weights).cuda().unsqueeze(dim=1)
)
raw_action = (actions_for_curr_step * exp_weights).sum(
dim=0, keepdim=True
)
else:
raw_action = all_actions[:, t % query_frequency]
elif self.config["policy_class"] == "Diffusion":
if t % query_frequency == 0:
all_actions = self.policy(qpos, current_image)
# if self.real_robot:
# all_actions = torch.cat(
# [
# all_actions[:, :-BASE_DELAY, :-2],
# all_actions[:, BASE_DELAY:, -2:],
# ],
# dim=2,
# )
raw_action = all_actions[:, t % query_frequency]
elif self.config["policy_class"] == "CNNMLP":
raw_action = self.policy(qpos, current_image)
all_actions = raw_action.unsqueeze(0)
else:
raise NotImplementedError
### post-process actions
raw_action = raw_action.squeeze(0).cpu().numpy()
action = self.post_process(raw_action)
### step the environment
if self.real_robot:
logging.info(f" action = {action}")
timestep = self.env.step(action)
rewards.append(timestep.reward)
duration = time.time() - time_1
sleep_time = max(0, DT - duration)
time.sleep(sleep_time)
if duration >= DT:
culmulated_delay += duration - DT
logging.warning(
f"Warning: step duration: {duration:.3f} s at step {t} longer than DT: {DT} s, culmulated delay: {culmulated_delay:.3f} s"
)
logging.info(f"Avg fps {max_timesteps / (time.time() - time_0)}")
plt.close()
if self.real_robot:
log_id = self.get_auto_index(self.checkpoint_dir)
np.save(
os.path.join(self.checkpoint_dir, f"qpos_{log_id}.npy"),
qpos_history_raw,
)
plt.figure(figsize=(10, 20))
for i in range(self.state_dim):
plt.subplot(self.state_dim, 1, i + 1)
plt.plot(qpos_history_raw[:, i])
if i != self.state_dim - 1:
plt.xticks([])
plt.tight_layout()
plt.savefig(os.path.join(self.checkpoint_dir, f"qpos_{log_id}.png"))
plt.close()
rewards = np.array(rewards)
episode_return = np.sum(rewards[rewards != None])
episode_returns.append(episode_return)
episode_highest_reward = np.max(rewards)
highest_rewards.append(episode_highest_reward)
logging.info(
f"Rollout {rollout_id}\n{episode_return=}, {episode_highest_reward=}, {self.env_max_reward=}, Success: {episode_highest_reward == self.env_max_reward}"
)
success_rate = np.mean(np.array(highest_rewards) == self.env_max_reward)
avg_return = np.mean(episode_returns)
summary_str = (
f"\nSuccess rate: {success_rate}\nAverage return: {avg_return}\n\n"
)
for r in range(self.env_max_reward + 1):
more_or_equal_r = (np.array(highest_rewards) >= r).sum()
more_or_equal_r_rate = more_or_equal_r / self.num_rollouts
summary_str += f"Reward >= {r}: {more_or_equal_r}/{self.num_rollouts} = {more_or_equal_r_rate * 100}%\n"
logging.info(summary_str)
result_file_name = "result_" + self.checkpoint_name.split(".")[0] + ".txt"
with open(os.path.join(self.checkpoint_dir, result_file_name), "w") as f:
f.write(summary_str)
f.write(repr(episode_returns))
f.write("\n\n")
f.write(repr(highest_rewards))
return success_rate, avg_return
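def _example_temporal_ensemble(actions_for_curr_step, k=0.01):
    # Illustrative sketch (an addition, not part of the original class): a numpy version of the
    # exponential weighting used in evaluate() above. Expects an array of shape
    # (num_predictions, action_dim); row 0 is the oldest prediction for the current timestep
    # and receives the largest weight, w_i proportional to exp(-k * i).
    exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
    exp_weights = exp_weights / exp_weights.sum()
    return (actions_for_curr_step * exp_weights[:, None]).sum(axis=0)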
class DeviceAloha:
def __init__(self, aloha_config):
"""
初始化设备
Args:
device_name (str): 设备名称
"""
config_left_arm = aloha_config["rm_left_arm"]
config_right_arm = aloha_config["rm_right_arm"]
config_head_camera = aloha_config["head_camera"]
config_bottom_camera = aloha_config["bottom_camera"]
config_left_camera = aloha_config["left_camera"]
config_right_camera = aloha_config["right_camera"]
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
self.arm_axis = aloha_config["arm_axis"]
self.arm_left = RmArm(config_left_arm)
self.arm_right = RmArm(config_right_arm)
self.camera_left = RealSenseCamera(config_head_camera, False)
self.camera_right = RealSenseCamera(config_bottom_camera, False)
self.camera_bottom = RealSenseCamera(config_left_camera, False)
self.camera_top = RealSenseCamera(config_right_camera, False)
self.camera_left.start_camera()
self.camera_right.start_camera()
self.camera_bottom.start_camera()
self.camera_top.start_camera()
def close(self):
"""
关闭摄像头
"""
self.camera_left.close()
self.camera_right.close()
self.camera_bottom.close()
self.camera_top.close()
def get_qps(self):
"""
获取关节角度
Returns:
np.array: 关节角度
"""
left_slave_arm_angle = self.arm_left.get_joint_angle()
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
right_slave_arm_angle = self.arm_right.get_joint_angle()
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
def get_qvel(self):
"""
获取关节速度
Returns:
np.array: 关节速度
"""
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
def get_effort(self):
"""
获取关节力
Returns:
np.array: 关节力
"""
left_slave_arm_effort = self.arm_left.get_joint_effort()
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
right_slave_arm_effort = self.arm_right.get_joint_effort()
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
def get_images(self):
"""
获取图像
Returns:
dict: 图像字典
"""
self.top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
self.bottom_image, _, _, _ = self.camera_bottom.read_frame(
True, False, False, False
)
self.left_image, _, _, _ = self.camera_left.read_frame(
True, False, False, False
)
self.right_image, _, _, _ = self.camera_right.read_frame(
True, False, False, False
)
return {
"cam_high": self.top_image,
"cam_low": self.bottom_image,
"cam_left": self.left_image,
"cam_right": self.right_image,
}
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qps()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
return obs
def reset(self):
logging.info("Resetting the environment")
self.arm_left.set_joint_position(self.init_left_arm_angle[0:self.arm_axis])
self.arm_right.set_joint_position(self.init_right_arm_angle[0:self.arm_axis])
self.arm_left.set_gripper_position(0)
self.arm_right.set_gripper_position(0)
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST,
reward=0,
discount=None,
observation=self.get_observation(),
)
    def step(self, target_angle):
        # target_angle layout: [left joints (arm_axis), left gripper, right joints (arm_axis), right gripper]
self.arm_left.set_joint_canfd_position(target_angle[0:self.arm_axis])
self.arm_right.set_joint_canfd_position(target_angle[self.arm_axis+1:self.arm_axis*2+1])
self.arm_left.set_gripper_position(target_angle[self.arm_axis])
self.arm_right.set_gripper_position(target_angle[(self.arm_axis*2 + 1)])
return dm_env.TimeStep(
step_type=dm_env.StepType.MID,
reward=0,
discount=None,
observation=self.get_observation(),
)
if __name__ == "__main__":
# with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
# config = yaml.safe_load(f)
# aloha_config = config["robot_env"]
# device = DeviceAloha(aloha_config)
# device.reset()
# while True:
# init_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
# time_step = time.time()
# timestep = device.step(init_angle)
# logging.info(f"Time: {time.time() - time_step}")
# obs = timestep.observation
with open("/home/wang/project/shadow_rm_act/config/config.yaml", "r") as f:
config = yaml.safe_load(f)
# logging.info(f"Config: {config}")
evaluator = RmActEvaluator(config)
success_rate, avg_return = evaluator.evaluate()

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,153 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
import torch
import torchvision
from torch import nn
from typing import Dict, List
import torch.nn.functional as F
from .position_encoding import build_position_encoding
from torchvision.models import ResNet18_Weights
from torchvision.models._utils import IntermediateLayerGetter
from shadow_act.utils.misc import NestedTensor, is_main_process
class FrozenBatchNorm2d(torch.nn.Module):
"""
BatchNorm2d where the batch statistics and the affine parameters are fixed.
    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
produce nans.
"""
def __init__(self, n):
super(FrozenBatchNorm2d, self).__init__()
self.register_buffer("weight", torch.ones(n))
self.register_buffer("bias", torch.zeros(n))
self.register_buffer("running_mean", torch.zeros(n))
self.register_buffer("running_var", torch.ones(n))
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super(FrozenBatchNorm2d, self)._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
def forward(self, x):
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
eps = 1e-5
scale = w * (rv + eps).rsqrt()
bias = b - rm * scale
return x * scale + bias
class BackboneBase(nn.Module):
def __init__(
self,
backbone: nn.Module,
train_backbone: bool,
num_channels: int,
return_interm_layers: bool,
):
super().__init__()
# for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
# if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
# parameter.requires_grad_(False)
if return_interm_layers:
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
else:
return_layers = {"layer4": "0"}
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
self.num_channels = num_channels
def forward(self, tensor):
xs = self.body(tensor)
return xs
# out: Dict[str, NestedTensor] = {}
# for name, x in xs.items():
# m = tensor_list.mask
# assert m is not None
# mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
# out[name] = NestedTensor(x, mask)
# return out
class Backbone(BackboneBase):
"""ResNet backbone with frozen BatchNorm."""
def __init__(
self,
name: str,
train_backbone: bool,
return_interm_layers: bool,
dilation: bool,
):
backbone = getattr(torchvision.models, name)(
replace_stride_with_dilation=[False, False, dilation],
weights=ResNet18_Weights.IMAGENET1K_V1 if is_main_process() else None,
norm_layer=FrozenBatchNorm2d,
)
# backbone = getattr(torchvision.models, name)(
# replace_stride_with_dilation=[False, False, dilation],
# pretrained=is_main_process(),
# norm_layer=FrozenBatchNorm2d,
# ) # pretrained # TODO do we want frozen batch_norm??
num_channels = 512 if name in ("resnet18", "resnet34") else 2048
super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
class Joiner(nn.Sequential):
def __init__(self, backbone, position_embedding):
super().__init__(backbone, position_embedding)
def forward(self, tensor_list: NestedTensor):
xs = self[0](tensor_list)
out: List[NestedTensor] = []
pos = []
for name, x in xs.items():
out.append(x)
# position encoding
pos.append(self[1](x).to(x.dtype))
return out, pos
def build_backbone(
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
):
position_embedding = build_position_encoding(
hidden_dim=hidden_dim, position_embedding_type=position_embedding_type
)
train_backbone = lr_backbone > 0
return_interm_layers = masks
backbone = Backbone(backbone, train_backbone, return_interm_layers, dilation)
model = Joiner(backbone, position_embedding)
model.num_channels = backbone.num_channels
return model
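def _example_build_backbone():
    # Illustrative sketch (an addition, not part of the original module): build the backbone with
    # arguments matching the policy_config in config.yaml (resnet18 backbone, sine position
    # embedding, hidden_dim 512, lr_backbone 1e-5, no dilation, no masks).
    return build_backbone(hidden_dim=512, position_embedding_type='sine',
                          lr_backbone=1.0e-5, masks=False,
                          backbone='resnet18', dilation=False)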

View File

@@ -0,0 +1,436 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
import torch.nn.functional as F
from shadow_act.models.transformer import Transformer
from .backbone import build_backbone
from .transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer
import numpy as np
def reparametrize(mu, logvar):
std = logvar.div(2).exp()
eps = Variable(std.data.new(std.size()).normal_())
return mu + std * eps
def get_sinusoid_encoding_table(n_position, d_hid):
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (hid_j // 2) / d_hid)
for hid_j in range(d_hid)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos_i) for pos_i in range(n_position)]
)
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
class DETRVAE(nn.Module):
"""This is the DETR module that performs object detection"""
def __init__(
self,
backbones,
transformer,
encoder,
state_dim,
num_queries,
camera_names,
vq,
vq_class,
vq_dim,
action_dim,
):
"""Initializes the model.
Parameters:
backbones: torch module of the backbone to be used. See backbone.py
transformer: torch module of the transformer architecture. See transformer.py
state_dim: robot state dimension of the environment
            num_queries: number of action queries, i.e. the length of the
                predicted action chunk.
aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
"""
super().__init__()
self.num_queries = num_queries
self.camera_names = camera_names
self.transformer = transformer
self.encoder = encoder
self.vq, self.vq_class, self.vq_dim = vq, vq_class, vq_dim
self.state_dim, self.action_dim = state_dim, action_dim
hidden_dim = transformer.d_model
self.action_head = nn.Linear(hidden_dim, action_dim)
self.is_pad_head = nn.Linear(hidden_dim, 1)
self.query_embed = nn.Embedding(num_queries, hidden_dim)
if backbones is not None:
self.input_proj = nn.Conv2d(
backbones[0].num_channels, hidden_dim, kernel_size=1
)
self.backbones = nn.ModuleList(backbones)
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
else:
# input_dim = 14 + 7 # robot_state + env_state
self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
self.input_proj_env_state = nn.Linear(7, hidden_dim)
self.pos = torch.nn.Embedding(2, hidden_dim)
self.backbones = None
# encoder extra parameters
self.latent_dim = 32 # final size of latent z # TODO tune
self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
self.encoder_action_proj = nn.Linear(
action_dim, hidden_dim
) # project action to embedding
self.encoder_joint_proj = nn.Linear(
action_dim, hidden_dim
) # project qpos to embedding
if self.vq:
self.latent_proj = nn.Linear(hidden_dim, self.vq_class * self.vq_dim)
else:
self.latent_proj = nn.Linear(
hidden_dim, self.latent_dim * 2
) # project hidden state to latent std, var
self.register_buffer(
"pos_table", get_sinusoid_encoding_table(1 + 1 + num_queries, hidden_dim)
) # [CLS], qpos, a_seq
# decoder extra parameters
if self.vq:
self.latent_out_proj = nn.Linear(self.vq_class * self.vq_dim, hidden_dim)
else:
self.latent_out_proj = nn.Linear(
self.latent_dim, hidden_dim
) # project latent sample to embedding
self.additional_pos_embed = nn.Embedding(
2, hidden_dim
) # learned position embedding for proprio and latent
def encode(self, qpos, actions=None, is_pad=None, vq_sample=None):
bs, _ = qpos.shape
if self.encoder is None:
latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(
qpos.device
)
latent_input = self.latent_out_proj(latent_sample)
probs = binaries = mu = logvar = None
else:
# cvae encoder
is_training = actions is not None # train or val
### Obtain latent z from action sequence
if is_training:
# project action sequence to embedding dim, and concat with a CLS token
action_embed = self.encoder_action_proj(
actions
) # (bs, seq, hidden_dim)
qpos_embed = self.encoder_joint_proj(qpos) # (bs, hidden_dim)
qpos_embed = torch.unsqueeze(qpos_embed, axis=1) # (bs, 1, hidden_dim)
cls_embed = self.cls_embed.weight # (1, hidden_dim)
cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(
bs, 1, 1
) # (bs, 1, hidden_dim)
encoder_input = torch.cat(
[cls_embed, qpos_embed, action_embed], axis=1
) # (bs, seq+1, hidden_dim)
encoder_input = encoder_input.permute(
1, 0, 2
) # (seq+1, bs, hidden_dim)
# do not mask cls token
cls_joint_is_pad = torch.full((bs, 2), False).to(
qpos.device
) # False: not a padding
is_pad = torch.cat([cls_joint_is_pad, is_pad], axis=1) # (bs, seq+1)
# obtain position embedding
pos_embed = self.pos_table.clone().detach()
pos_embed = pos_embed.permute(1, 0, 2) # (seq+1, 1, hidden_dim)
# query model
encoder_output = self.encoder(
encoder_input, pos=pos_embed, src_key_padding_mask=is_pad
)
encoder_output = encoder_output[0] # take cls output only
latent_info = self.latent_proj(encoder_output)
if self.vq:
logits = latent_info.reshape(
[*latent_info.shape[:-1], self.vq_class, self.vq_dim]
)
probs = torch.softmax(logits, dim=-1)
binaries = (
F.one_hot(
torch.multinomial(probs.view(-1, self.vq_dim), 1).squeeze(
-1
),
self.vq_dim,
)
.view(-1, self.vq_class, self.vq_dim)
.float()
)
binaries_flat = binaries.view(-1, self.vq_class * self.vq_dim)
probs_flat = probs.view(-1, self.vq_class * self.vq_dim)
                    straight_through = binaries_flat - probs_flat.detach() + probs_flat
                    latent_input = self.latent_out_proj(straight_through)
mu = logvar = None
else:
probs = binaries = None
mu = latent_info[:, : self.latent_dim]
logvar = latent_info[:, self.latent_dim :]
latent_sample = reparametrize(mu, logvar)
latent_input = self.latent_out_proj(latent_sample)
else:
mu = logvar = binaries = probs = None
if self.vq:
latent_input = self.latent_out_proj(
vq_sample.view(-1, self.vq_class * self.vq_dim)
)
else:
latent_sample = torch.zeros(
[bs, self.latent_dim], dtype=torch.float32
).to(qpos.device)
latent_input = self.latent_out_proj(latent_sample)
return latent_input, probs, binaries, mu, logvar
def forward(
self, qpos, image, env_state, actions=None, is_pad=None, vq_sample=None
):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
latent_input, probs, binaries, mu, logvar = self.encode(
qpos, actions, is_pad, vq_sample
)
# cvae decoder
if self.backbones is not None:
# Image observation features and position embeddings
all_cam_features = []
all_cam_pos = []
for cam_id, cam_name in enumerate(self.camera_names):
# TODO: fix this error
features, pos = self.backbones[0](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0]
all_cam_features.append(self.input_proj(features))
all_cam_pos.append(pos)
# proprioception features
proprio_input = self.input_proj_robot_state(qpos)
# fold camera dimension into width dimension
src = torch.cat(all_cam_features, axis=3)
pos = torch.cat(all_cam_pos, axis=3)
hs = self.transformer(
src,
None,
self.query_embed.weight,
pos,
latent_input,
proprio_input,
self.additional_pos_embed.weight,
)[0]
else:
qpos = self.input_proj_robot_state(qpos)
env_state = self.input_proj_env_state(env_state)
transformer_input = torch.cat([qpos, env_state], axis=1) # seq length = 2
hs = self.transformer(
transformer_input, None, self.query_embed.weight, self.pos.weight
)[0]
a_hat = self.action_head(hs)
is_pad_hat = self.is_pad_head(hs)
return a_hat, is_pad_hat, [mu, logvar], probs, binaries
class CNNMLP(nn.Module):
def __init__(self, backbones, state_dim, camera_names):
"""Initializes the model.
Parameters:
            backbones: torch modules of the image backbones to be used. See backbone.py
            state_dim: robot state dimension of the environment
            camera_names: list of camera names; one backbone feature head per camera
"""
super().__init__()
self.camera_names = camera_names
self.action_head = nn.Linear(1000, state_dim) # TODO add more
if backbones is not None:
self.backbones = nn.ModuleList(backbones)
backbone_down_projs = []
for backbone in backbones:
down_proj = nn.Sequential(
nn.Conv2d(backbone.num_channels, 128, kernel_size=5),
nn.Conv2d(128, 64, kernel_size=5),
nn.Conv2d(64, 32, kernel_size=5),
)
backbone_down_projs.append(down_proj)
self.backbone_down_projs = nn.ModuleList(backbone_down_projs)
mlp_in_dim = 768 * len(backbones) + state_dim
self.mlp = mlp(
input_dim=mlp_in_dim,
hidden_dim=1024,
                output_dim=state_dim,  # the predicted action shares the state dimensionality here
hidden_depth=2,
)
else:
raise NotImplementedError
def forward(self, qpos, image, env_state, actions=None):
"""
qpos: batch, qpos_dim
image: batch, num_cam, channel, height, width
env_state: None
actions: batch, seq, action_dim
"""
is_training = actions is not None # train or val
bs, _ = qpos.shape
# Image observation features and position embeddings
all_cam_features = []
for cam_id, cam_name in enumerate(self.camera_names):
features, pos = self.backbones[cam_id](image[:, cam_id])
features = features[0] # take the last layer feature
pos = pos[0] # not used
all_cam_features.append(self.backbone_down_projs[cam_id](features))
# flatten everything
flattened_features = []
for cam_feature in all_cam_features:
flattened_features.append(cam_feature.reshape([bs, -1]))
flattened_features = torch.cat(flattened_features, axis=1) # 768 each
features = torch.cat([flattened_features, qpos], axis=1) # qpos: 14
a_hat = self.mlp(features)
return a_hat
def mlp(input_dim, hidden_dim, output_dim, hidden_depth):
if hidden_depth == 0:
mods = [nn.Linear(input_dim, output_dim)]
else:
mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)]
for i in range(hidden_depth - 1):
mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)]
mods.append(nn.Linear(hidden_dim, output_dim))
trunk = nn.Sequential(*mods)
return trunk
def build_encoder(
hidden_dim, # 256
dropout, # 0.1
nheads, # 8
dim_feedforward,
num_encoder_layers, # 4 # TODO shared with VAE decoder
normalize_before, # False
):
activation = "relu"
encoder_layer = TransformerEncoderLayer(
hidden_dim, nheads, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
return encoder
def build_vae(
hidden_dim,
state_dim,
position_embedding_type,
lr_backbone,
masks,
backbone,
dilation,
dropout,
nheads,
dim_feedforward,
enc_layers,
dec_layers,
pre_norm,
num_queries,
camera_names,
vq,
vq_class,
vq_dim,
action_dim,
no_encoder,
):
# TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
backbone = build_backbone(
hidden_dim, position_embedding_type, lr_backbone, masks, backbone, dilation
)
backbones.append(backbone)
transformer = build_transformer(
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
)
if no_encoder:
encoder = None
else:
encoder = build_encoder(
hidden_dim,
dropout,
nheads,
dim_feedforward,
enc_layers,
pre_norm,
)
model = DETRVAE(
backbones,
transformer,
encoder,
state_dim,
num_queries,
camera_names,
vq,
vq_class,
vq_dim,
action_dim,
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
return model
# TODO
def build_cnnmlp(args):
state_dim = 14 # TODO hardcode
# From state
# backbone = None # from state for now, no need for conv nets
# From image
backbones = []
    for _ in args["camera_names"]:
        # `args` is expected to be a config dict; the keys below mirror those consumed by build_vae/ACTPolicy.
        backbone = build_backbone(
            args["hidden_dim"], args["position_embedding"], args["lr_backbone"],
            args["masks"], args["backbone"], args["dilation"],
        )
        backbones.append(backbone)
    model = CNNMLP(
        backbones,
        state_dim=state_dim,
        camera_names=args["camera_names"],
)
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number of parameters: %.2fM" % (n_parameters / 1e6,))
return model
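if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; the hyperparameters and camera
    # list below are assumptions, not this repo's configured values). It builds
    # the CVAE policy model and runs one inference-style forward pass on CPU;
    # run as a module so the package-relative imports resolve, and note the
    # backbone downloads ImageNet resnet18 weights on first use.
    model = build_vae(
        hidden_dim=512, state_dim=14, position_embedding_type="sine",
        lr_backbone=1e-5, masks=False, backbone="resnet18", dilation=False,
        dropout=0.1, nheads=8, dim_feedforward=3200, enc_layers=4,
        dec_layers=7, pre_norm=False, num_queries=100, camera_names=["top"],
        vq=False, vq_class=None, vq_dim=None, action_dim=14, no_encoder=False,
    )
    qpos = torch.randn(2, 14)
    image = torch.randn(2, 1, 3, 480, 640)
    a_hat, is_pad_hat, (mu, logvar), _, _ = model(qpos, image, None)
    print(a_hat.shape)  # (2, 100, 14): a chunk of 100 future actions per sample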

View File

@@ -0,0 +1,122 @@
#!/usr/bin/env python3
import torch.nn as nn
from torch.nn import functional as F
import torch
DROPOUT_RATE = 0.1  # dropout rate used throughout this module
# A causal (autoregressive) Transformer block
class Causal_Transformer_Block(nn.Module):
    def __init__(self, seq_len, latent_dim, num_head) -> None:
        """
        Initialize the causal Transformer block.
        Args:
            seq_len (int): sequence length
            latent_dim (int): latent dimension
            num_head (int): number of attention heads
        """
        super().__init__()
        self.num_head = num_head
        self.latent_dim = latent_dim
        self.ln_1 = nn.LayerNorm(latent_dim)  # layer norm before attention
        self.attn = nn.MultiheadAttention(latent_dim, num_head, dropout=DROPOUT_RATE, batch_first=True)  # multi-head attention
        self.ln_2 = nn.LayerNorm(latent_dim)  # layer norm before the MLP
        self.mlp = nn.Sequential(
            nn.Linear(latent_dim, 4 * latent_dim),  # expand
            nn.GELU(),  # GELU activation
            nn.Linear(4 * latent_dim, latent_dim),  # project back to latent_dim
            nn.Dropout(DROPOUT_RATE),  # dropout
        )
        # self.register_buffer("attn_mask", torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool())  # precomputed causal mask (unused)
    def forward(self, x):
        """
        Forward pass.
        Args:
            x (torch.Tensor): input tensor
        Returns:
            torch.Tensor: output tensor
        """
        # Upper-triangular causal mask so a position cannot attend to future positions
        attn_mask = torch.triu(torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool), diagonal=1)
        x = self.ln_1(x)  # layer norm
        x = x + self.attn(x, x, x, attn_mask=attn_mask)[0]  # residual attention
        x = self.ln_2(x)  # layer norm
        x = x + self.mlp(x)  # residual MLP
        return x
# Model the latent-code sequence with self-attention instead of an RNN
class Latent_Model_Transformer(nn.Module):
    def __init__(self, input_dim, output_dim, seq_len, latent_dim=256, num_head=8, num_layer=3) -> None:
        """
        Initialize the latent-prior Transformer.
        Args:
            input_dim (int): input dimension
            output_dim (int): output dimension
            seq_len (int): sequence length
            latent_dim (int, optional): latent dimension, defaults to 256
            num_head (int, optional): number of attention heads, defaults to 8
            num_layer (int, optional): number of Transformer layers, defaults to 3
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.seq_len = seq_len
        self.latent_dim = latent_dim
        self.num_head = num_head
        self.num_layer = num_layer
        self.input_layer = nn.Linear(input_dim, latent_dim)  # input projection
        self.weight_pos_embed = nn.Embedding(seq_len, latent_dim)  # positional embedding
        self.attention_blocks = nn.Sequential(
            nn.Dropout(DROPOUT_RATE),  # dropout
            *[Causal_Transformer_Block(seq_len, latent_dim, num_head) for _ in range(num_layer)],  # stacked causal blocks
            nn.LayerNorm(latent_dim)  # final layer norm
        )
        self.output_layer = nn.Linear(latent_dim, output_dim)  # output projection
    def forward(self, x):
        """
        Forward pass.
        Args:
            x (torch.Tensor): input tensor
        Returns:
            torch.Tensor: output logits
        """
        x = self.input_layer(x)  # input projection
        x = x + self.weight_pos_embed(torch.arange(x.shape[1], device=x.device))  # add positional embedding
        x = self.attention_blocks(x)  # causal attention blocks
        logits = self.output_layer(x)  # output projection
        return logits
    @torch.no_grad()
    def generate(self, n, temperature=0.1, x=None):
        """
        Autoregressively sample sequences.
        Args:
            n (int): number of sequences to generate
            temperature (float, optional): sampling temperature, defaults to 0.1
            x (torch.Tensor, optional): initial input tensor, defaults to None
        Returns:
            torch.Tensor: generated sequences
        """
        if x is None:
            x = torch.zeros((n, 1, self.input_dim), device=self.weight_pos_embed.weight.device)  # start token (all zeros)
        for i in range(self.seq_len):
            logits = self.forward(x)[:, -1]  # logits for the last time step
            probs = torch.softmax(logits / temperature, dim=-1)  # sampling distribution
            samples = torch.multinomial(probs, num_samples=1)[..., 0]  # draw one sample per sequence
            samples_one_hot = F.one_hot(samples.long(), num_classes=self.output_dim).float()  # one-hot encode
            x = torch.cat([x, samples_one_hot[:, None, :]], dim=1)  # append to the running sequence
        return x[:, 1:, :]  # return the generated sequence (drop the initial zero token)
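if __name__ == "__main__":
    # Illustrative sketch (assumed toy dimensions): autoregressively sample
    # one-hot latent codes from an untrained prior. input_dim must equal
    # output_dim so the sampled one-hot codes can be fed back in.
    prior = Latent_Model_Transformer(input_dim=32, output_dim=32, seq_len=4)
    codes = prior.generate(n=2, temperature=1.0)
    print(codes.shape)  # (2, 4, 32); each step is a one-hot vector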

View File

@@ -0,0 +1,91 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Various positional encodings for the transformer.
"""
import math
import torch
from torch import nn
from shadow_act.utils.misc import NestedTensor
class PositionEmbeddingSine(nn.Module):
"""
This is a more standard version of the position embedding, very similar to the one
used by the Attention is all you need paper, generalized to work on images.
"""
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
super().__init__()
self.num_pos_feats = num_pos_feats
self.temperature = temperature
self.normalize = normalize
if scale is not None and normalize is False:
raise ValueError("normalize should be True if scale is passed")
if scale is None:
scale = 2 * math.pi
self.scale = scale
def forward(self, tensor):
x = tensor
# mask = tensor_list.mask
# assert mask is not None
# not_mask = ~mask
not_mask = torch.ones_like(x[0, [0]])
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)
if self.normalize:
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
return pos
class PositionEmbeddingLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, num_pos_feats=256):
super().__init__()
self.row_embed = nn.Embedding(50, num_pos_feats)
self.col_embed = nn.Embedding(50, num_pos_feats)
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.row_embed.weight)
nn.init.uniform_(self.col_embed.weight)
def forward(self, tensor_list: NestedTensor):
x = tensor_list.tensors
h, w = x.shape[-2:]
i = torch.arange(w, device=x.device)
j = torch.arange(h, device=x.device)
x_emb = self.col_embed(i)
y_emb = self.row_embed(j)
pos = torch.cat([
x_emb.unsqueeze(0).repeat(h, 1, 1),
y_emb.unsqueeze(1).repeat(1, w, 1),
], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
return pos
def build_position_encoding(hidden_dim, position_embedding_type):
N_steps = hidden_dim // 2
if position_embedding_type in ('v2', 'sine'):
# TODO find a better way of exposing other arguments
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
elif position_embedding_type in ('v3', 'learned'):
position_embedding = PositionEmbeddingLearned(N_steps)
else:
raise ValueError(f"not supported {position_embedding_type}")
return position_embedding
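if __name__ == "__main__":
    # Illustrative sketch (assumed feature-map size): for hidden_dim=256 the
    # sine variant returns a (1, 256, H, W) map computed purely from the
    # spatial grid, so the batch dimension stays 1 and is broadcast later.
    pe = build_position_encoding(hidden_dim=256, position_embedding_type="sine")
    pos = pe(torch.randn(2, 512, 15, 20))
    print(pos.shape)  # (1, 256, 15, 20)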

View File

@@ -0,0 +1,424 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
* positional encodings are passed in MHattention
* extra LN at the end of encoder is removed
* decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import Optional, List
import torch
import torch.nn.functional as F
from torch import nn, Tensor
class Transformer(nn.Module):
def __init__(
self,
d_model=512,
nhead=8,
num_encoder_layers=6,
num_decoder_layers=6,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
return_intermediate_dec=False,
):
super().__init__()
encoder_layer = TransformerEncoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
self.encoder = TransformerEncoder(
encoder_layer, num_encoder_layers, encoder_norm
)
decoder_layer = TransformerDecoderLayer(
d_model, nhead, dim_feedforward, dropout, activation, normalize_before
)
decoder_norm = nn.LayerNorm(d_model)
self.decoder = TransformerDecoder(
decoder_layer,
num_decoder_layers,
decoder_norm,
return_intermediate=return_intermediate_dec,
)
self._reset_parameters()
self.d_model = d_model
self.nhead = nhead
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(
self,
src,
mask,
query_embed,
pos_embed,
latent_input=None,
proprio_input=None,
additional_pos_embed=None,
):
# TODO flatten only when input has H and W
if len(src.shape) == 4: # has H and W
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
# mask = mask.flatten(1)
additional_pos_embed = additional_pos_embed.unsqueeze(1).repeat(
1, bs, 1
) # seq, bs, dim
pos_embed = torch.cat([additional_pos_embed, pos_embed], axis=0)
addition_input = torch.stack([latent_input, proprio_input], axis=0)
src = torch.cat([addition_input, src], axis=0)
else:
assert len(src.shape) == 3
# flatten NxHWxC to HWxNxC
bs, hw, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.unsqueeze(1).repeat(1, bs, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(
tgt,
memory,
memory_key_padding_mask=mask,
pos=pos_embed,
query_pos=query_embed,
)
hs = hs.transpose(1, 2)
return hs
class TransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers, norm=None):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
src,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
output = src
for layer in self.layers:
output = layer(
output,
src_mask=mask,
src_key_padding_mask=src_key_padding_mask,
pos=pos,
)
if self.norm is not None:
output = self.norm(output)
return output
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
self.return_intermediate = return_intermediate
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
output = tgt
intermediate = []
for layer in self.layers:
output = layer(
output,
memory,
tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos,
query_pos=query_pos,
)
if self.return_intermediate:
intermediate.append(self.norm(output))
if self.norm is not None:
output = self.norm(output)
if self.return_intermediate:
intermediate.pop()
intermediate.append(output)
if self.return_intermediate:
return torch.stack(intermediate)
return output.unsqueeze(0)
class TransformerEncoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(src, pos)
src2 = self.self_attn(
q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
def forward_pre(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.self_attn(
q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
src = src + self.dropout2(src2)
return src
def forward(
self,
src,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class TransformerDecoderLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
q = k = self.with_pos_embed(tgt, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt = self.norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
tgt2 = self.self_attn(
q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
)[0]
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt,
memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None,
):
if self.normalize_before:
return self.forward_pre(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
return self.forward_post(
tgt,
memory,
tgt_mask,
memory_mask,
tgt_key_padding_mask,
memory_key_padding_mask,
pos,
query_pos,
)
def _get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
def build_transformer(
hidden_dim, dropout, nheads, dim_feedforward, enc_layers, dec_layers, pre_norm
):
return Transformer(
d_model=hidden_dim,
dropout=dropout,
nhead=nheads,
dim_feedforward=dim_feedforward,
num_encoder_layers=enc_layers,
num_decoder_layers=dec_layers,
normalize_before=pre_norm,
return_intermediate_dec=True,
)
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
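if __name__ == "__main__":
    # Illustrative sketch (assumed sizes): exercise the no-image branch, which
    # expects a (bs, seq, d_model) source plus one positional embedding per token.
    t = build_transformer(
        hidden_dim=512, dropout=0.1, nheads=8, dim_feedforward=3200,
        enc_layers=4, dec_layers=7, pre_norm=False,
    )
    src = torch.randn(2, 2, 512)   # bs=2, two input tokens
    pos = torch.randn(2, 512)      # positional embedding per token
    query = torch.randn(100, 512)  # 100 action queries
    hs = t(src, None, query, pos)
    print(hs.shape)  # (7, 2, 100, 512): one output per decoder layer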

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,522 @@
import torch
import logging
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import torchvision.transforms as transforms
from shadow_act.models.detr_vae import build_vae, build_cnnmlp
# from diffusers.training_utils import EMAModel
# from robomimic.models.base_nets import ResNet18Conv, SpatialSoftmax
# from robomimic.algo.diffusion_policy import replace_bn_with_gn, ConditionalUnet1D
# from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
# from diffusers.schedulers.scheduling_ddim import DDIMScheduler
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# TODO: refactor the DiffusionPolicy class. Note that it depends on the
# diffusers/robomimic imports commented out above; re-enable them to use it.
class DiffusionPolicy(nn.Module):
    def __init__(self, args_override):
        """
        Initialize the DiffusionPolicy class.
        Args:
            args_override (dict): parameter override dictionary
        """
super().__init__()
self.camera_names = args_override["camera_names"]
self.observation_horizon = args_override["observation_horizon"]
self.action_horizon = args_override["action_horizon"]
self.prediction_horizon = args_override["prediction_horizon"]
self.num_inference_timesteps = args_override["num_inference_timesteps"]
self.ema_power = args_override["ema_power"]
self.lr = args_override["lr"]
self.weight_decay = 0
self.num_kp = 32
self.feature_dimension = 64
self.ac_dim = args_override["action_dim"]
self.obs_dim = self.feature_dimension * len(self.camera_names) + 14
backbones = []
pools = []
linears = []
for _ in self.camera_names:
backbones.append(
ResNet18Conv(input_channel=3, pretrained=False, input_coord_conv=False)
)
pools.append(
SpatialSoftmax(
input_shape=[512, 15, 20],
num_kp=self.num_kp,
temperature=1.0,
learnable_temperature=False,
noise_std=0.0,
)
)
linears.append(
torch.nn.Linear(int(np.prod([self.num_kp, 2])), self.feature_dimension)
)
backbones = nn.ModuleList(backbones)
pools = nn.ModuleList(pools)
linears = nn.ModuleList(linears)
backbones = replace_bn_with_gn(backbones)
noise_pred_net = ConditionalUnet1D(
input_dim=self.ac_dim,
global_cond_dim=self.obs_dim * self.observation_horizon,
)
nets = nn.ModuleDict(
{
"policy": nn.ModuleDict(
{
"backbones": backbones,
"pools": pools,
"linears": linears,
"noise_pred_net": noise_pred_net,
}
)
}
)
nets = nets.float().cuda()
ENABLE_EMA = True
if ENABLE_EMA:
ema = EMAModel(model=nets, power=self.ema_power)
else:
ema = None
self.nets = nets
self.ema = ema
        # Set up the noise scheduler
self.noise_scheduler = DDIMScheduler(
num_train_timesteps=50,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
set_alpha_to_one=True,
steps_offset=0,
prediction_type="epsilon",
)
n_parameters = sum(p.numel() for p in self.parameters())
logger.info("number of parameters: %.2fM", n_parameters / 1e6)
def configure_optimizers(self):
"""
配置优化器
Returns:
optimizer: 配置的优化器
"""
optimizer = torch.optim.AdamW(
self.nets.parameters(), lr=self.lr, weight_decay=self.weight_decay
)
return optimizer
def __call__(self, qpos, image, actions=None, is_pad=None):
"""
前向传播函数
Args:
qpos (torch.Tensor): 位置张量
image (torch.Tensor): 图像张量
actions (torch.Tensor, optional): 动作张量
is_pad (torch.Tensor, optional): 填充张量
Returns:
dict: 损失字典(训练时)
torch.Tensor: 动作张量(推理时)
"""
B = qpos.shape[0]
if actions is not None: # 训练时
nets = self.nets
all_features = []
for cam_id in range(len(self.camera_names)):
cam_image = image[:, cam_id]
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
pool_features = nets["policy"]["pools"][cam_id](cam_features)
pool_features = torch.flatten(pool_features, start_dim=1)
out_features = nets["policy"]["linears"][cam_id](pool_features)
all_features.append(out_features)
obs_cond = torch.cat(all_features + [qpos], dim=1)
            # Add noise to the actions
            noise = torch.randn(actions.shape, device=obs_cond.device)
            # Sample one diffusion iteration (timestep) per data point
            timesteps = torch.randint(
                0,
                self.noise_scheduler.config.num_train_timesteps,
                (B,),
                device=obs_cond.device,
            ).long()
            # Add noise to the clean actions according to each timestep's noise magnitude
            noisy_actions = self.noise_scheduler.add_noise(actions, noise, timesteps)
            # Predict the noise residual
            noise_pred = nets["policy"]["noise_pred_net"](
                noisy_actions, timesteps, global_cond=obs_cond
            )
            # L2 loss
            all_l2 = F.mse_loss(noise_pred, noise, reduction="none")
            loss = (all_l2 * ~is_pad.unsqueeze(-1)).mean()
            loss_dict = {}
            loss_dict["l2_loss"] = loss
            loss_dict["loss"] = loss
            if self.training and self.ema is not None:
                self.ema.step(nets)
            return loss_dict
        else:  # inference
To = self.observation_horizon
Ta = self.action_horizon
Tp = self.prediction_horizon
action_dim = self.ac_dim
nets = self.nets
if self.ema is not None:
nets = self.ema.averaged_model
all_features = []
for cam_id in range(len(self.camera_names)):
cam_image = image[:, cam_id]
cam_features = nets["policy"]["backbones"][cam_id](cam_image)
pool_features = nets["policy"]["pools"][cam_id](cam_features)
pool_features = torch.flatten(pool_features, start_dim=1)
out_features = nets["policy"]["linears"][cam_id](pool_features)
all_features.append(out_features)
obs_cond = torch.cat(all_features + [qpos], dim=1)
            # Initialize the action sequence from Gaussian noise
            noisy_action = torch.randn((B, Tp, action_dim), device=obs_cond.device)
            naction = noisy_action
            # Initialize the scheduler
            self.noise_scheduler.set_timesteps(self.num_inference_timesteps)
            for k in self.noise_scheduler.timesteps:
                # Predict the noise
                noise_pred = nets["policy"]["noise_pred_net"](
                    sample=naction, timestep=k, global_cond=obs_cond
                )
                # Reverse diffusion step (remove the noise)
naction = self.noise_scheduler.step(
model_output=noise_pred, timestep=k, sample=naction
).prev_sample
return naction
def serialize(self):
"""
序列化模型
Returns:
dict: 模型状态字典
"""
return {
"nets": self.nets.state_dict(),
"ema": (
self.ema.averaged_model.state_dict() if self.ema is not None else None
),
}
def deserialize(self, model_dict):
"""
反序列化模型
Args:
model_dict (dict): 模型状态字典
Returns:
status: 加载状态
"""
status = self.nets.load_state_dict(model_dict["nets"])
logger.info("Loaded model")
if model_dict.get("ema", None) is not None:
logger.info("Loaded EMA")
status_ema = self.ema.averaged_model.load_state_dict(model_dict["ema"])
status = [status, status_ema]
return status
class ACTPolicy(nn.Module):
def __init__(self, act_config):
"""
初始化ACTPolicy类
Args:
args_override (dict): 参数覆盖字典
"""
super().__init__()
lr_backbone = act_config["lr_backbone"]
vq = act_config["vq"]
lr = act_config["lr"]
weight_decay = act_config["weight_decay"]
model = build_vae(
act_config["hidden_dim"],
act_config["state_dim"],
act_config["position_embedding"],
lr_backbone,
act_config["masks"],
act_config["backbone"],
act_config["dilation"],
act_config["dropout"],
act_config["nheads"],
act_config["dim_feedforward"],
act_config["enc_layers"],
act_config["dec_layers"],
act_config["pre_norm"],
act_config["num_queries"],
act_config["camera_names"],
vq,
act_config["vq_class"],
act_config["vq_dim"],
act_config["action_dim"],
act_config["no_encoder"],
)
model.cuda()
param_dicts = [
{
"params": [
p
for n, p in model.named_parameters()
if "backbone" not in n and p.requires_grad
]
},
{
"params": [
p
for n, p in model.named_parameters()
if "backbone" in n and p.requires_grad
],
"lr": lr_backbone,
},
]
self.optimizer = torch.optim.AdamW(
param_dicts, lr=lr, weight_decay=weight_decay
)
        self.model = model  # CVAE decoder
self.kl_weight = act_config["kl_weight"]
self.vq = vq
logger.info(f"KL Weight {self.kl_weight}")
def __call__(self, qpos, image, actions=None, is_pad=None, vq_sample=None):
"""
前向传播函数
Args:
qpos (torch.Tensor): 角度张量
image (torch.Tensor): 图像张量
actions (torch.Tensor, optional): 动作张量
is_pad (torch.Tensor, optional): 填充张量
vq_sample (torch.Tensor, optional): VQ样本
Returns:
dict: 损失字典(训练时)
torch.Tensor: 动作张量(推理时)
"""
env_state = None
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
image = normalize(image)
if actions is not None: # 训练时
actions = actions[:, : self.model.num_queries]
is_pad = is_pad[:, : self.model.num_queries]
loss_dict = dict()
a_hat, is_pad_hat, (mu, logvar), probs, binaries = self.model(
qpos, image, env_state, actions, is_pad
)
if self.vq or self.model.encoder is None:
total_kld = [torch.tensor(0.0)]
else:
total_kld, dim_wise_kld, mean_kld = kl_divergence(mu, logvar)
if self.vq:
loss_dict["vq_discrepancy"] = F.l1_loss(
probs, binaries, reduction="mean"
)
all_l1 = F.l1_loss(actions, a_hat, reduction="none")
l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()
loss_dict["l1"] = l1
loss_dict["kl"] = total_kld[0]
loss_dict["loss"] = loss_dict["l1"] + loss_dict["kl"] * self.kl_weight
return loss_dict
        else:  # inference
a_hat, _, (_, _), _, _ = self.model(
qpos, image, env_state, vq_sample=vq_sample
) # no action, sample from prior
return a_hat
def configure_optimizers(self):
"""
配置优化器
Returns:
optimizer: 配置的优化器
"""
return self.optimizer
@torch.no_grad()
def vq_encode(self, qpos, actions, is_pad):
"""
VQ编码
Args:
qpos (torch.Tensor): 位置张量
actions (torch.Tensor): 动作张量
is_pad (torch.Tensor): 填充张量
Returns:
torch.Tensor: 二进制编码
"""
actions = actions[:, : self.model.num_queries]
is_pad = is_pad[:, : self.model.num_queries]
_, _, binaries, _, _ = self.model.encode(qpos, actions, is_pad)
return binaries
def serialize(self):
"""
序列化模型
Returns:
dict: 模型状态字典
"""
return self.state_dict()
def deserialize(self, model_dict):
"""
反序列化模型
Args:
model_dict (dict): 模型状态字典
Returns:
status: 加载状态
"""
return self.load_state_dict(model_dict)
class CNNMLPPolicy(nn.Module):
def __init__(self, args_override):
"""
初始化CNNMLPPolicy类
Args:
args_override (dict): 参数覆盖字典
"""
super().__init__()
# parser = argparse.ArgumentParser(
# "DETR training and evaluation script", parents=[get_args_parser()]
# )
# args = parser.parse_args()
# for k, v in args_override.items():
# setattr(args, k, v)
model = build_cnnmlp(args_override)
model.cuda()
param_dicts = [
{
"params": [
p
for n, p in model.named_parameters()
if "backbone" not in n and p.requires_grad
]
},
{
"params": [
p
for n, p in model.named_parameters()
if "backbone" in n and p.requires_grad
],
"lr": args_override.lr_backbone,
},
]
self.model = model # 解码器
self.optimizer = torch.optim.AdamW(
param_dicts, lr=args_override.lr, weight_decay=args_override.weight_decay
)
def __call__(self, qpos, image, actions=None, is_pad=None):
"""
前向传播函数
Args:
qpos (torch.Tensor): 位置张量
image (torch.Tensor): 图像张量
actions (torch.Tensor, optional): 动作张量
is_pad (torch.Tensor, optional): 填充张量
Returns:
dict: 损失字典(训练时)
torch.Tensor: 动作张量(推理时)
"""
env_state = None
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
image = normalize(image)
if actions is not None: # 训练时
actions = actions[:, 0]
a_hat = self.model(qpos, image, env_state, actions)
mse = F.mse_loss(actions, a_hat)
loss_dict = dict()
loss_dict["mse"] = mse
loss_dict["loss"] = loss_dict["mse"]
return loss_dict
        else:  # inference
            a_hat = self.model(qpos, image, env_state)  # no actions; sample from the prior
return a_hat
def configure_optimizers(self):
"""
配置优化器
Returns:
optimizer: 配置的优化器
"""
return self.optimizer
def kl_divergence(mu, logvar):
"""
计算KL散度
Args:
mu (torch.Tensor): 均值张量
logvar (torch.Tensor): 对数方差张量
Returns:
tuple: 总KL散度维度KL散度均值KL散度
"""
batch_size = mu.size(0)
assert batch_size != 0
if mu.data.ndimension() == 4:
mu = mu.view(mu.size(0), mu.size(1))
if logvar.data.ndimension() == 4:
logvar = logvar.view(logvar.size(0), logvar.size(1))
klds = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
total_kld = klds.sum(1).mean(0, True)
dimension_wise_kld = klds.mean(0)
mean_kld = klds.mean(1).mean(0, True)
return total_kld, dimension_wise_kld, mean_kld
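if __name__ == "__main__":
    # Illustrative numeric check of kl_divergence (toy inputs): for a diagonal
    # Gaussian posterior, KL(N(mu, sigma^2) || N(0, 1)) per dimension is
    # -0.5 * (1 + logvar - mu^2 - exp(logvar)); with mu = 0 and logvar = 0 the
    # posterior equals the prior and the KL term vanishes.
    mu = torch.zeros(2, 3)
    logvar = torch.zeros(2, 3)
    total_kld, _, _ = kl_divergence(mu, logvar)
    print(total_kld.item())  # 0.0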

View File

@@ -0,0 +1,245 @@
import os
import yaml
import pickle
import torch
# import wandb
import logging
import numpy as np
from tqdm import tqdm
from copy import deepcopy
from itertools import repeat
from shadow_act.utils.utils import (
set_seed,
load_data,
compute_dict_mean,
)
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.eval.rm_act_eval import RmActEvaluator
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
class RmActTrainer:
    def __init__(self, config):
        """
        Initialize the trainer: load the data and save the dataset statistics.
        """
self._config = config
self._num_steps = config["num_steps"]
self._ckpt_dir = config["checkpoint_dir"]
self._state_dim = config["state_dim"]
self._real_robot = config["real_robot"]
self._policy_class = config["policy_class"]
self._onscreen_render = config["onscreen_render"]
self._policy_config = config["policy_config"]
self._camera_names = config["camera_names"]
self._max_timesteps = config["episode_len"]
self._task_name = config["task_name"]
self._temporal_agg = config["temporal_agg"]
self._onscreen_cam = "angle"
self._vq = config["policy_config"]["vq"]
self._batch_size = config["batch_size"]
self._seed = config["seed"]
self._eval_every = config["eval_every"]
self._validate_every = config["validate_every"]
self._save_every = config["save_every"]
self._load_pretrain = config["load_pretrain"]
self._resume_ckpt_path = config["resume_ckpt_path"]
if config["name_filter"] is None:
name_filter = lambda n : True
else:
name_filter = config["name_filter"]
self._eval = RmActEvaluator(config, True, 50)
        # Load the training / validation data
self._train_dataloader, self._val_dataloader, self._stats, _ = load_data(
config["dataset_dir"],
name_filter,
self._camera_names,
self._batch_size,
self._batch_size,
config["chunk_size"],
config["skip_mirrored_data"],
self._load_pretrain,
self._policy_class,
config["stats_dir"],
config["sample_weights"],
config["train_ratio"],
)
        # Save the dataset statistics
stats_path = os.path.join(self._ckpt_dir, "dataset_stats.pkl")
with open(stats_path, "wb") as f:
pickle.dump(self._stats, f)
expr_name = self._ckpt_dir.split("/")[-1]
# wandb.init(
# project="train_rm_aloha", reinit=True, entity="train_rm_aloha", name=expr_name
# )
def _make_policy(self):
"""
根据策略类和配置创建策略对象
"""
if self._policy_class == "ACT":
return ACTPolicy(self._policy_config)
elif self._policy_class == "CNNMLP":
return CNNMLPPolicy(self._policy_config)
elif self._policy_class == "Diffusion":
return DiffusionPolicy(self._policy_config)
else:
raise NotImplementedError(f"Policy class {self._policy_class} is not implemented")
def _make_optimizer(self):
"""
根据策略类创建优化器
"""
if self._policy_class in ["ACT", "CNNMLP", "Diffusion"]:
return self._policy.configure_optimizers()
else:
raise NotImplementedError
def _forward_pass(self, data):
"""
前向传播,计算损失
"""
image_data, qpos_data, action_data, is_pad = data
try:
image_data, qpos_data, action_data, is_pad = (
image_data.cuda(),
qpos_data.cuda(),
action_data.cuda(),
is_pad.cuda(),
)
except RuntimeError as e:
logging.error(f"CUDA error: {e}")
raise
return self._policy(qpos_data, image_data, action_data, is_pad)
def _repeater(self):
"""
数据加载器的重复器,生成数据
"""
epoch = 0
for loader in repeat(self._train_dataloader):
for data in loader:
yield data
logging.info(f"Epoch {epoch} done")
epoch += 1
def train(self):
"""
训练模型,保存最佳模型
"""
set_seed(self._seed)
self._policy = self._make_policy()
min_val_loss = np.inf
best_ckpt_info = None
        # Load a pretrained model
if self._load_pretrain:
try:
loading_status = self._policy.deserialize(
torch.load(
os.path.join(
"/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all",
"policy_step_50000_seed_0.ckpt",
)
)
)
logging.info(f"loaded! {loading_status}")
except FileNotFoundError as e:
logging.error(f"Pretrain model not found: {e}")
except Exception as e:
logging.error(f"Error loading pretrain model: {e}")
        # Resume from a checkpoint
if self._resume_ckpt_path is not None:
try:
loading_status = self._policy.deserialize(torch.load(self._resume_ckpt_path))
logging.info(f"Resume policy from: {self._resume_ckpt_path}, Status: {loading_status}")
except FileNotFoundError as e:
logging.error(f"Checkpoint not found: {e}")
except Exception as e:
logging.error(f"Error loading checkpoint: {e}")
self._policy.cuda()
self._optimizer = self._make_optimizer()
        train_dataloader = self._repeater()  # infinite batch iterator
for step in tqdm(range(self._num_steps + 1)):
            # Validation
            if step % self._validate_every == 0:
                logging.info("validating")
                with torch.inference_mode():
                    self._policy.eval()
                    validation_dicts = []
                    for batch_idx, data in enumerate(self._val_dataloader):
                        forward_dict = self._forward_pass(data)  # forward_dict = {"loss": loss, "kl": kl, "mse": mse}
                        validation_dicts.append(forward_dict)
                        if batch_idx > 50:  # cap the number of validation batches  TODO: confirm how many batches to use
                            break
                    validation_summary = compute_dict_mean(validation_dicts)
                    epoch_val_loss = validation_summary["loss"]
                    if epoch_val_loss < min_val_loss:
                        min_val_loss = epoch_val_loss
                        best_ckpt_info = (
                            step,
                            min_val_loss,
                            deepcopy(self._policy.serialize()),
                        )
                # Log validation results to wandb
                # for k in list(validation_summary.keys()):
                #     validation_summary[f"val_{k}"] = validation_summary.pop(k)
                # wandb.log(validation_summary, step=step)
                logging.info(f"Val loss: {epoch_val_loss:.5f}")
                summary_string = " ".join(f"{k}: {v.item():.3f}" for k, v in validation_summary.items())
                logging.info(summary_string)
            # Evaluate the model
# if (step > 0) and (step % self._eval_every == 0):
# ckpt_name = f"policy_step_{step}_seed_{self._seed}.ckpt"
# ckpt_path = os.path.join(self._ckpt_dir, ckpt_name)
# torch.save(self._policy.serialize(), ckpt_path)
# success, _ = self._eval.evaluate(ckpt_name)
# wandb.log({"success": success}, step=step)
            # Training step
self._policy.train()
self._optimizer.zero_grad()
data = next(train_dataloader)
forward_dict = self._forward_pass(data)
loss = forward_dict["loss"]
loss.backward()
self._optimizer.step()
# wandb.log(forward_dict, step=step)
            # Save a checkpoint
if step % self._save_every == 0:
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{step}_seed_{self._seed}.ckpt")
torch.save(self._policy.serialize(), ckpt_path)
        # Save the final model
ckpt_path = os.path.join(self._ckpt_dir, "policy_last.ckpt")
torch.save(self._policy.serialize(), ckpt_path)
best_step, min_val_loss, best_state_dict = best_ckpt_info
ckpt_path = os.path.join(self._ckpt_dir, f"policy_step_{best_step}_seed_{self._seed}.ckpt")
torch.save(best_state_dict, ckpt_path)
logging.info(f"Training finished:\nSeed {self._seed}, val loss {min_val_loss:.6f} at step {best_step}")
return best_ckpt_info
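# Illustrative note: the config dict consumed above (loaded from YAML in the
# __main__ block below) is expected to provide at least these keys; the exact
# values are project-specific and not shown here:
#   num_steps, checkpoint_dir, state_dim, real_robot, policy_class,
#   onscreen_render, policy_config (including policy_config["vq"]),
#   camera_names, episode_len, task_name, temporal_agg, batch_size, seed,
#   eval_every, validate_every, save_every, load_pretrain, resume_ckpt_path,
#   name_filter, dataset_dir, chunk_size, skip_mirrored_data, stats_dir,
#   sample_weights, train_ratio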
if __name__ == "__main__":
with open("/home/rm/aloha/shadow_rm_act/config/config.yaml") as f:
config = yaml.load(f, Loader=yaml.FullLoader)
trainer = RmActTrainer(config)
trainer.train()

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,88 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
(x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=-1)
def box_xyxy_to_cxcywh(x):
x0, y0, x1, y1 = x.unbind(-1)
b = [(x0 + x1) / 2, (y0 + y1) / 2,
(x1 - x0), (y1 - y0)]
return torch.stack(b, dim=-1)
# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
area1 = box_area(boxes1)
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / union
return iou, union
def generalized_box_iou(boxes1, boxes2):
"""
Generalized IoU from https://giou.stanford.edu/
The boxes should be in [x0, y0, x1, y1] format
Returns a [N, M] pairwise matrix, where N = len(boxes1)
and M = len(boxes2)
"""
# degenerate boxes gives inf / nan results
# so do an early check
assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
iou, union = box_iou(boxes1, boxes2)
lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
wh = (rb - lt).clamp(min=0) # [N,M,2]
area = wh[:, :, 0] * wh[:, :, 1]
return iou - (area - union) / area
def masks_to_boxes(masks):
"""Compute the bounding boxes around the provided masks
The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
Returns a [N, 4] tensors, with the boxes in xyxy format
"""
if masks.numel() == 0:
return torch.zeros((0, 4), device=masks.device)
h, w = masks.shape[-2:]
y = torch.arange(0, h, dtype=torch.float)
x = torch.arange(0, w, dtype=torch.float)
y, x = torch.meshgrid(y, x)
x_mask = (masks * x.unsqueeze(0))
x_max = x_mask.flatten(1).max(-1)[0]
x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
y_mask = (masks * y.unsqueeze(0))
y_max = y_mask.flatten(1).max(-1)[0]
y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
return torch.stack([x_min, y_min, x_max, y_max], 1)
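if __name__ == "__main__":
    # Illustrative check (toy boxes): [0,0,2,2] vs [1,1,3,3] intersect in a
    # unit square, so IoU = 1/7; the enclosing box has area 9, giving
    # GIoU = 1/7 - (9 - 7)/9, roughly -0.079.
    b1 = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
    b2 = torch.tensor([[1.0, 1.0, 3.0, 3.0]])
    print(box_iou(b1, b2)[0].item(), generalized_box_iou(b1, b2).item())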

View File

@@ -0,0 +1,468 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from packaging import version
from typing import Optional, List
import torch
import torch.distributed as dist
from torch import Tensor
# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
if version.parse(torchvision.__version__) < version.parse('0.7'):
from torchvision.ops import _new_empty_tensor
from torchvision.ops.misc import _output_size
class SmoothedValue(object):
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
if not is_dist_avail_and_initialized():
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
dist.barrier()
dist.all_reduce(t)
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median,
avg=self.avg,
global_avg=self.global_avg,
max=self.max,
value=self.value)
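def _smoothed_value_demo():
    # Illustrative sketch (hypothetical helper, not referenced elsewhere):
    # track a noisy metric and read back the windowed median and global average.
    sv = SmoothedValue(window_size=3)
    for v in [1.0, 2.0, 3.0, 10.0]:
        sv.update(v)
    # The window now holds [2, 3, 10] -> median 3.0; global average = 16 / 4 = 4.0.
    return str(sv)  # "3.0000 (4.0000)"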
def all_gather(data):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors)
Args:
data: any picklable object
Returns:
list[data]: list of data gathered from each rank
"""
world_size = get_world_size()
if world_size == 1:
return [data]
# serialized to a Tensor
buffer = pickle.dumps(data)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to("cuda")
# obtain Tensor size of each rank
local_size = torch.tensor([tensor.numel()], device="cuda")
size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
dist.all_gather(size_list, local_size)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# receiving Tensor from all ranks
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
tensor_list = []
for _ in size_list:
tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
if local_size != max_size:
padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
tensor = torch.cat((tensor, padding), dim=0)
dist.all_gather(tensor_list, tensor)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
def reduce_dict(input_dict, average=True):
"""
Args:
input_dict (dict): all the values will be reduced
average (bool): whether to do average or sum
Reduce the values in the dictionary from all processes so that all processes
have the averaged results. Returns a dict with the same fields as
input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.all_reduce(values)
if average:
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
class MetricLogger(object):
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
assert isinstance(v, (float, int))
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, attr))
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(
"{}: {}".format(name, str(meter))
)
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ''
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt='{avg:.4f}')
data_time = SmoothedValue(fmt='{avg:.4f}')
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
if torch.cuda.is_available():
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}',
'max mem: {memory:.0f}'
])
else:
log_msg = self.delimiter.join([
header,
'[{0' + space_fmt + '}/{1}]',
'eta: {eta}',
'{meters}',
'time: {time}',
'data: {data}'
])
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0 or i == len(iterable) - 1:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB))
else:
print(log_msg.format(
i, len(iterable), eta=eta_string,
meters=str(self),
time=str(iter_time), data=str(data_time)))
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('{} Total time: {} ({:.4f} s / it)'.format(
header, total_time_str, total_time / len(iterable)))
def get_sha():
cwd = os.path.dirname(os.path.abspath(__file__))
def _run(command):
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
sha = 'N/A'
diff = "clean"
branch = 'N/A'
try:
sha = _run(['git', 'rev-parse', 'HEAD'])
subprocess.check_output(['git', 'diff'], cwd=cwd)
diff = _run(['git', 'diff-index', 'HEAD'])
diff = "has uncommited changes" if diff else "clean"
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
except Exception:
pass
message = f"sha: {sha}, status: {diff}, branch: {branch}"
return message
def collate_fn(batch):
batch = list(zip(*batch))
batch[0] = nested_tensor_from_tensor_list(batch[0])
return tuple(batch)
def _max_by_axis(the_list):
# type: (List[List[int]]) -> List[int]
maxes = the_list[0]
for sublist in the_list[1:]:
for index, item in enumerate(sublist):
maxes[index] = max(maxes[index], item)
return maxes
class NestedTensor(object):
def __init__(self, tensors, mask: Optional[Tensor]):
self.tensors = tensors
self.mask = mask
def to(self, device):
# type: (Device) -> NestedTensor # noqa
cast_tensor = self.tensors.to(device)
mask = self.mask
if mask is not None:
assert mask is not None
cast_mask = mask.to(device)
else:
cast_mask = None
return NestedTensor(cast_tensor, cast_mask)
def decompose(self):
return self.tensors, self.mask
def __repr__(self):
return str(self.tensors)
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
# TODO make this more general
if tensor_list[0].ndim == 3:
if torchvision._is_tracing():
# nested_tensor_from_tensor_list() does not export well to ONNX
# call _onnx_nested_tensor_from_tensor_list() instead
return _onnx_nested_tensor_from_tensor_list(tensor_list)
# TODO make it support different-sized images
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
batch_shape = [len(tensor_list)] + max_size
b, c, h, w = batch_shape
dtype = tensor_list[0].dtype
device = tensor_list[0].device
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
for img, pad_img, m in zip(tensor_list, tensor, mask):
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
m[: img.shape[1], :img.shape[2]] = False
else:
raise ValueError('not supported')
return NestedTensor(tensor, mask)
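# Worked example (shapes only): for two RGB images of shape 3x200x300 and
# 3x180x320, max_size is [3, 200, 320], so `tensor` is zero-padded to
# [2, 3, 200, 320] and `mask` is [2, 200, 320] with True exactly at padded pixels.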
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
max_size = []
for i in range(tensor_list[0].dim()):
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
max_size.append(max_size_i)
max_size = tuple(max_size)
# work around for
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
# m[: img.shape[1], :img.shape[2]] = False
# which is not yet supported in onnx
padded_imgs = []
padded_masks = []
for img in tensor_list:
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
padded_imgs.append(padded_img)
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
padded_masks.append(padded_mask.to(torch.bool))
tensor = torch.stack(padded_imgs)
mask = torch.stack(padded_masks)
return NestedTensor(tensor, mask=mask)
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
elif 'SLURM_PROCID' in os.environ:
args.rank = int(os.environ['SLURM_PROCID'])
args.gpu = args.rank % torch.cuda.device_count()
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}'.format(
args.rank, args.dist_url), flush=True)
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
if target.numel() == 0:
return [torch.zeros([], device=output.device)]
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
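# Worked example: for logits `output` of shape [B, C] and integer labels `target`
# of shape [B], accuracy(output, target, topk=(1, 5)) returns two 0-dim tensors,
# the top-1 and top-5 accuracy in percent (100 * hits / B).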
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
# type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
"""
Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
This will eventually be supported natively by PyTorch, and this
class can go away.
"""
if version.parse(torchvision.__version__) < version.parse('0.7'):
if input.numel() > 0:
return torch.nn.functional.interpolate(
input, size, scale_factor, mode, align_corners
)
output_shape = _output_size(2, input, size, scale_factor)
output_shape = list(input.shape[:-2]) + list(output_shape)
return _new_empty_tensor(input, output_shape)
else:
return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)

View File

@@ -0,0 +1,107 @@
"""
Plotting utilities to visualize training logs.
"""
import torch
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path, PurePath
def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
'''
Function to plot specific fields from training log(s). Plots both training and test results.
:: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
- fields = which results to plot from each log file - plots both training and test for each field.
- ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
- log_name = optional, name of log file if different than default 'log.txt'.
:: Outputs - matplotlib plots of results in fields, color coded for each log file.
- solid lines are training results, dashed lines are test results.
'''
func_name = "plot_utils.py::plot_logs"
# verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
# convert single Path to list to avoid 'not iterable' error
if not isinstance(logs, list):
if isinstance(logs, PurePath):
logs = [logs]
print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
else:
raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
Expect list[Path] or single Path obj, received {type(logs)}")
# Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
for i, dir in enumerate(logs):
if not isinstance(dir, PurePath):
raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
if not dir.exists():
raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
# verify log_name exists
fn = Path(dir / log_name)
if not fn.exists():
print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?")
print(f"--> full path of missing log file: {fn}")
return
# load log file(s) and plot
dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
for j, field in enumerate(fields):
if field == 'mAP':
coco_eval = pd.DataFrame(
np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
).ewm(com=ewm_col).mean()
axs[j].plot(coco_eval, c=color)
else:
df.interpolate().ewm(com=ewm_col).mean().plot(
y=[f'train_{field}', f'test_{field}'],
ax=axs[j],
color=[color] * 2,
style=['-', '--']
)
for ax, field in zip(axs, fields):
ax.legend([Path(p).name for p in logs])
ax.set_title(field)
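# Usage sketch (hypothetical output dirs; assumes each dir holds a log.txt with
# train_/test_ prefixed columns written by the training loop):
#   from pathlib import Path
#   plot_logs([Path('output/run_a'), Path('output/run_b')],
#             fields=('loss', 'class_error'), ewm_col=5)
#   plt.show()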
def plot_precision_recall(files, naming_scheme='iter'):
if naming_scheme == 'exp_id':
# name becomes exp_id
names = [f.parts[-3] for f in files]
elif naming_scheme == 'iter':
names = [f.stem for f in files]
else:
raise ValueError(f'not supported {naming_scheme}')
fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
data = torch.load(f)
# precision is n_iou, n_points, n_cat, n_area, max_det
precision = data['precision']
recall = data['params'].recThrs
scores = data['scores']
# take precision for all classes, all areas and 100 detections
precision = precision[0, :, :, 0, -1].mean(1)
scores = scores[0, :, :, 0, -1].mean(1)
prec = precision.mean()
rec = data['recall'][0, :, 0, -1].mean()
print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
f'score={scores.mean():0.3f}, ' +
f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
)
axs[0].plot(recall, precision, c=color)
axs[1].plot(recall, scores, c=color)
axs[0].set_title('Precision / Recall')
axs[0].legend(names)
axs[1].set_title('Scores / Recall')
axs[1].legend(names)
return fig, axs

View File

@@ -0,0 +1,499 @@
import numpy as np
import torch
import os
import h5py
import pickle
import fnmatch
import cv2
from time import time
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms
def flatten_list(l):
return [item for sublist in l for item in sublist]
class EpisodicDataset(torch.utils.data.Dataset):
def __init__(
self,
dataset_path_list,
camera_names,
norm_stats,
episode_ids,
episode_len,
chunk_size,
policy_class,
):
super(EpisodicDataset).__init__()
self.episode_ids = episode_ids
self.dataset_path_list = dataset_path_list
self.camera_names = camera_names
self.norm_stats = norm_stats
self.episode_len = episode_len
self.chunk_size = chunk_size
self.cumulative_len = np.cumsum(self.episode_len)
self.max_episode_len = max(episode_len)
self.policy_class = policy_class
if self.policy_class == "Diffusion":
self.augment_images = True
else:
self.augment_images = False
self.transformations = None
self.__getitem__(0) # initialize self.is_sim and self.transformations
self.is_sim = False
# def __len__(self):
# return sum(self.episode_len)
def _locate_transition(self, index):
assert index < self.cumulative_len[-1]
episode_index = np.argmax(
self.cumulative_len > index
) # argmax returns first True index
start_ts = index - (
self.cumulative_len[episode_index] - self.episode_len[episode_index]
)
episode_id = self.episode_ids[episode_index]
return episode_id, start_ts
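# Worked example: with episode_len = [3, 5], cumulative_len = [3, 8]; a flat
# index of 4 gives episode_index = 1 (first cumulative value > 4) and
# start_ts = 4 - (8 - 5) = 1, i.e. the second timestep of the second episode.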
def __getitem__(self, index):
episode_id, start_ts = self._locate_transition(index)
dataset_path = self.dataset_path_list[episode_id]
try:
# print(dataset_path)
with h5py.File(dataset_path, "r") as root:
try: # some legacy data does not have this attribute
is_sim = root.attrs["sim"]
except:
is_sim = False
compressed = root.attrs.get("compress", False)
if "/base_action" in root:
base_action = root["/base_action"][()]
base_action = preprocess_base_action(base_action)
action = np.concatenate([root["/action"][()], base_action], axis=-1)
else:
# TODO
action = root["/action"][()]
# dummy_base_action = np.zeros([action.shape[0], 2])
# action = np.concatenate([action, dummy_base_action], axis=-1)
original_action_shape = action.shape
episode_len = original_action_shape[0]
# get observation at start_ts only
qpos = root["/observations/qpos"][start_ts]
qvel = root["/observations/qvel"][start_ts]
image_dict = dict()
for cam_name in self.camera_names:
image_dict[cam_name] = root[f"/observations/images/{cam_name}"][
start_ts
]
if compressed:
for cam_name in image_dict.keys():
decompressed_image = cv2.imdecode(image_dict[cam_name], 1)
image_dict[cam_name] = np.array(decompressed_image)
# get all actions after and including start_ts
if is_sim:
action = action[start_ts:]
action_len = episode_len - start_ts
else:
action = action[
max(0, start_ts - 1) :
] # hack, to make timesteps more aligned
action_len = episode_len - max(
0, start_ts - 1
) # hack, to make timesteps more aligned
# self.is_sim = is_sim
padded_action = np.zeros(
(self.max_episode_len, original_action_shape[1]), dtype=np.float32
)
padded_action[:action_len] = action
is_pad = np.zeros(self.max_episode_len)
is_pad[action_len:] = 1
padded_action = padded_action[: self.chunk_size]
is_pad = is_pad[: self.chunk_size]
# new axis for different cameras
all_cam_images = []
for cam_name in self.camera_names:
all_cam_images.append(image_dict[cam_name])
all_cam_images = np.stack(all_cam_images, axis=0)
# construct observations
image_data = torch.from_numpy(all_cam_images)
qpos_data = torch.from_numpy(qpos).float()
action_data = torch.from_numpy(padded_action).float()
is_pad = torch.from_numpy(is_pad).bool()
# channel last
image_data = torch.einsum("k h w c -> k c h w", image_data)
# augmentation
if self.transformations is None:
print("Initializing transformations")
original_size = image_data.shape[2:]
ratio = 0.95
self.transformations = [
transforms.RandomCrop(
size=[
int(original_size[0] * ratio),
int(original_size[1] * ratio),
]
),
transforms.Resize(original_size, antialias=True),
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
transforms.ColorJitter(
brightness=0.3, contrast=0.4, saturation=0.5
), # , hue=0.08)
]
if self.augment_images:
for transform in self.transformations:
image_data = transform(image_data)
# normalize image and change dtype to float
image_data = image_data / 255.0
if self.policy_class == "Diffusion":
# normalize to [-1, 1]
action_data = (
(action_data - self.norm_stats["action_min"])
/ (self.norm_stats["action_max"] - self.norm_stats["action_min"])
) * 2 - 1
else:
# normalize to mean 0 std 1
action_data = (
action_data - self.norm_stats["action_mean"]
) / self.norm_stats["action_std"]
qpos_data = (qpos_data - self.norm_stats["qpos_mean"]) / self.norm_stats[
"qpos_std"
]
except:
print(f"Error loading {dataset_path} in __getitem__")
quit()
# print(image_data.dtype, qpos_data.dtype, action_data.dtype, is_pad.dtype)
return image_data, qpos_data, action_data, is_pad
def get_norm_stats(dataset_path_list):
all_qpos_data = []
all_action_data = []
all_episode_len = []
for dataset_path in dataset_path_list:
try:
with h5py.File(dataset_path, "r") as root:
qpos = root["/observations/qpos"][()]
qvel = root["/observations/qvel"][()]
if "/base_action" in root:
base_action = root["/base_action"][()]
# base_action = preprocess_base_action(base_action)
# action = np.concatenate([root["/action"][()], base_action], axis=-1)
else:
# TODO
action = root["/action"][()]
# dummy_base_action = np.zeros([action.shape[0], 2])
# action = np.concatenate([action, dummy_base_action], axis=-1)
except Exception as e:
print(f"Error loading {dataset_path} in get_norm_stats")
print(e)
quit()
all_qpos_data.append(torch.from_numpy(qpos))
all_action_data.append(torch.from_numpy(action))
all_episode_len.append(len(qpos))
all_qpos_data = torch.cat(all_qpos_data, dim=0)
all_action_data = torch.cat(all_action_data, dim=0)
# normalize action data
action_mean = all_action_data.mean(dim=[0]).float()
action_std = all_action_data.std(dim=[0]).float()
action_std = torch.clip(action_std, 1e-2, np.inf) # clipping
# normalize qpos data
qpos_mean = all_qpos_data.mean(dim=[0]).float()
qpos_std = all_qpos_data.std(dim=[0]).float()
qpos_std = torch.clip(qpos_std, 1e-2, np.inf) # clipping
action_min = all_action_data.min(dim=0).values.float()
action_max = all_action_data.max(dim=0).values.float()
eps = 0.0001
stats = {
"action_mean": action_mean.numpy(),
"action_std": action_std.numpy(),
"action_min": action_min.numpy() - eps,
"action_max": action_max.numpy() + eps,
"qpos_mean": qpos_mean.numpy(),
"qpos_std": qpos_std.numpy(),
"example_qpos": qpos,
}
return stats, all_episode_len
def find_all_hdf5(dataset_dir, skip_mirrored_data):
hdf5_files = []
for root, dirs, files in os.walk(dataset_dir):
for filename in fnmatch.filter(files, "*.hdf5"):
if "features" in filename:
continue
if skip_mirrored_data and "mirror" in filename:
continue
hdf5_files.append(os.path.join(root, filename))
print(f"Found {len(hdf5_files)} hdf5 files")
return hdf5_files
def BatchSampler(batch_size, episode_len_l, sample_weights):
sample_probs = (
np.array(sample_weights) / np.sum(sample_weights)
if sample_weights is not None
else None
)
# print("BatchSampler", sample_weights)
sum_dataset_len_l = np.cumsum(
[0] + [np.sum(episode_len) for episode_len in episode_len_l]
)
while True:
batch = []
for _ in range(batch_size):
episode_idx = np.random.choice(len(episode_len_l), p=sample_probs)
step_idx = np.random.randint(
sum_dataset_len_l[episode_idx], sum_dataset_len_l[episode_idx + 1]
)
batch.append(step_idx)
yield batch
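# Usage sketch: BatchSampler is an infinite generator yielding lists of flat step
# indices, so it is passed as `batch_sampler=` to a DataLoader (as in load_data
# below) and iterated for a fixed number of steps rather than epochs:
#   sampler = BatchSampler(8, [train_episode_len], sample_weights=None)
#   loader = DataLoader(train_dataset, batch_sampler=sampler, num_workers=2)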
def load_data(
dataset_dir_l,
name_filter,
camera_names,
batch_size_train,
batch_size_val,
chunk_size,
skip_mirrored_data=False,
load_pretrain=False,
policy_class=None,
stats_dir_l=None,
sample_weights=None,
train_ratio=0.99,
):
if type(dataset_dir_l) == str:
dataset_dir_l = [dataset_dir_l]
dataset_path_list_list = [
find_all_hdf5(dataset_dir, skip_mirrored_data) for dataset_dir in dataset_dir_l
]
num_episodes_0 = len(dataset_path_list_list[0])
dataset_path_list = flatten_list(dataset_path_list_list)
dataset_path_list = [n for n in dataset_path_list if name_filter(n)]
num_episodes_l = [
len(dataset_path_list) for dataset_path_list in dataset_path_list_list
]
num_episodes_cumsum = np.cumsum(num_episodes_l)
# obtain train test split on dataset_dir_l[0]
shuffled_episode_ids_0 = np.random.permutation(num_episodes_0)
train_episode_ids_0 = shuffled_episode_ids_0[: int(train_ratio * num_episodes_0)]
val_episode_ids_0 = shuffled_episode_ids_0[int(train_ratio * num_episodes_0) :]
train_episode_ids_l = [train_episode_ids_0] + [
np.arange(num_episodes) + num_episodes_cumsum[idx]
for idx, num_episodes in enumerate(num_episodes_l[1:])
]
val_episode_ids_l = [val_episode_ids_0]
train_episode_ids = np.concatenate(train_episode_ids_l)
val_episode_ids = np.concatenate(val_episode_ids_l)
print(
f"\n\nData from: {dataset_dir_l}\n- Train on {[len(x) for x in train_episode_ids_l]} episodes\n- Test on {[len(x) for x in val_episode_ids_l]} episodes\n\n"
)
# obtain normalization stats for qpos and action
# if load_pretrain:
# with open(os.path.join('/home/zfu/interbotix_ws/src/act/ckpts/pretrain_all', 'dataset_stats.pkl'), 'rb') as f:
# norm_stats = pickle.load(f)
# print('Loaded pretrain dataset stats')
_, all_episode_len = get_norm_stats(dataset_path_list)
train_episode_len_l = [
[all_episode_len[i] for i in train_episode_ids]
for train_episode_ids in train_episode_ids_l
]
val_episode_len_l = [
[all_episode_len[i] for i in val_episode_ids]
for val_episode_ids in val_episode_ids_l
]
train_episode_len = flatten_list(train_episode_len_l)
val_episode_len = flatten_list(val_episode_len_l)
if stats_dir_l is None:
stats_dir_l = dataset_dir_l
elif type(stats_dir_l) == str:
stats_dir_l = [stats_dir_l]
norm_stats, _ = get_norm_stats(
flatten_list(
[find_all_hdf5(stats_dir, skip_mirrored_data) for stats_dir in stats_dir_l]
)
)
print(f"Norm stats from: {stats_dir_l}")
batch_sampler_train = BatchSampler(
batch_size_train, train_episode_len_l, sample_weights
)
batch_sampler_val = BatchSampler(batch_size_val, val_episode_len_l, None)
# print(f'train_episode_len: {train_episode_len}, val_episode_len: {val_episode_len}, train_episode_ids: {train_episode_ids}, val_episode_ids: {val_episode_ids}')
# construct dataset and dataloader
train_dataset = EpisodicDataset(
dataset_path_list,
camera_names,
norm_stats,
train_episode_ids,
train_episode_len,
chunk_size,
policy_class,
)
val_dataset = EpisodicDataset(
dataset_path_list,
camera_names,
norm_stats,
val_episode_ids,
val_episode_len,
chunk_size,
policy_class,
)
train_num_workers = (
(8 if os.getlogin() == "zfu" else 16) if train_dataset.augment_images else 2
)
val_num_workers = 8 if train_dataset.augment_images else 2
print(
f"Augment images: {train_dataset.augment_images}, train_num_workers: {train_num_workers}, val_num_workers: {val_num_workers}"
)
train_dataloader = DataLoader(
train_dataset,
batch_sampler=batch_sampler_train,
pin_memory=True,
num_workers=train_num_workers,
prefetch_factor=2,
)
val_dataloader = DataLoader(
val_dataset,
batch_sampler=batch_sampler_val,
pin_memory=True,
num_workers=val_num_workers,
prefetch_factor=2,
)
return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim
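# Usage sketch (paths and camera names are placeholders for this repo's data):
#   train_loader, val_loader, stats, _ = load_data(
#       '/path/to/dataset_dir', name_filter=lambda n: True,
#       camera_names=['cam_high', 'cam_low', 'cam_left', 'cam_right'],
#       batch_size_train=8, batch_size_val=8, chunk_size=100, policy_class='ACT')
#   image, qpos, action, is_pad = next(iter(train_loader))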
def calibrate_linear_vel(base_action, c=None):
if c is None:
c = 0.0 # 0.19
v = base_action[..., 0]
w = base_action[..., 1]
base_action = base_action.copy()
base_action[..., 0] = v - c * w
return base_action
def smooth_base_action(base_action):
return np.stack(
[
np.convolve(base_action[:, i], np.ones(5) / 5, mode="same")
for i in range(base_action.shape[1])
],
axis=-1,
).astype(np.float32)
def preprocess_base_action(base_action):
# base_action = calibrate_linear_vel(base_action)
base_action = smooth_base_action(base_action)
return base_action
def postprocess_base_action(base_action):
linear_vel, angular_vel = base_action
linear_vel *= 1.0
angular_vel *= 1.0
# angular_vel = 0
# if np.abs(linear_vel) < 0.05:
# linear_vel = 0
return np.array([linear_vel, angular_vel])
### env utils
def sample_box_pose():
x_range = [0.0, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
cube_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
cube_quat = np.array([1, 0, 0, 0])
return np.concatenate([cube_position, cube_quat])
def sample_insertion_pose():
# Peg
x_range = [0.1, 0.2]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
peg_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
peg_quat = np.array([1, 0, 0, 0])
peg_pose = np.concatenate([peg_position, peg_quat])
# Socket
x_range = [-0.2, -0.1]
y_range = [0.4, 0.6]
z_range = [0.05, 0.05]
ranges = np.vstack([x_range, y_range, z_range])
socket_position = np.random.uniform(ranges[:, 0], ranges[:, 1])
socket_quat = np.array([1, 0, 0, 0])
socket_pose = np.concatenate([socket_position, socket_quat])
return peg_pose, socket_pose
### helper functions
def compute_dict_mean(epoch_dicts):
result = {k: None for k in epoch_dicts[0]}
num_items = len(epoch_dicts)
for k in result:
value_sum = 0
for epoch_dict in epoch_dicts:
value_sum += epoch_dict[k]
result[k] = value_sum / num_items
return result
def detach_dict(d):
new_d = dict()
for k, v in d.items():
new_d[k] = v.detach()
return new_d
def set_seed(seed):
torch.manual_seed(seed)
np.random.seed(seed)

View File

@@ -0,0 +1,163 @@
from shadow_camera.realsense import RealSenseCamera
from shadow_rm_robot.realman_arm import RmArm
import yaml
import time
import multiprocessing
import numpy as np
import collections
import logging
import dm_env
import tracemalloc
class DeviceAloha:
def __init__(self, aloha_config):
"""
Initialize the ALOHA device.
Args:
aloha_config (dict): robot environment config (arms, cameras, initial joint angles)
"""
config_left_arm = aloha_config["rm_left_arm"]
config_right_arm = aloha_config["rm_right_arm"]
config_head_camera = aloha_config["head_camera"]
config_bottom_camera = aloha_config["bottom_camera"]
config_left_camera = aloha_config["left_camera"]
config_right_camera = aloha_config["right_camera"]
self.init_left_arm_angle = aloha_config["init_left_arm_angle"]
self.init_right_arm_angle = aloha_config["init_right_arm_angle"]
self.arm_left = RmArm(config_left_arm)
self.arm_right = RmArm(config_right_arm)
self.camera_left = RealSenseCamera(config_head_camera, False)
self.camera_right = RealSenseCamera(config_bottom_camera, False)
self.camera_bottom = RealSenseCamera(config_left_camera, False)
self.camera_top = RealSenseCamera(config_right_camera, False)
self.camera_left.start_camera()
self.camera_right.start_camera()
self.camera_bottom.start_camera()
self.camera_top.start_camera()
def close(self):
"""
Close all cameras.
"""
self.camera_left.close()
self.camera_right.close()
self.camera_bottom.close()
self.camera_top.close()
def get_qps(self):
"""
Get the joint angles of both arms.
Returns:
np.array: concatenated left and right joint angles
"""
left_slave_arm_angle = self.arm_left.get_joint_angle()
left_joint_angles_array = np.array(list(left_slave_arm_angle.values()))
right_slave_arm_angle = self.arm_right.get_joint_angle()
right_joint_angles_array = np.array(list(right_slave_arm_angle.values()))
return np.concatenate([left_joint_angles_array, right_joint_angles_array])
def get_qvel(self):
"""
Get the joint velocities of both arms.
Returns:
np.array: concatenated left and right joint velocities
"""
left_slave_arm_velocity = self.arm_left.get_joint_velocity()
left_joint_velocity_array = np.array(list(left_slave_arm_velocity.values()))
right_slave_arm_velocity = self.arm_right.get_joint_velocity()
right_joint_velocity_array = np.array(list(right_slave_arm_velocity.values()))
return np.concatenate([left_joint_velocity_array, right_joint_velocity_array])
def get_effort(self):
"""
Get the joint efforts (torques) of both arms.
Returns:
np.array: concatenated left and right joint efforts
"""
left_slave_arm_effort = self.arm_left.get_joint_effort()
left_joint_effort_array = np.array(list(left_slave_arm_effort.values()))
right_slave_arm_effort = self.arm_right.get_joint_effort()
right_joint_effort_array = np.array(list(right_slave_arm_effort.values()))
return np.concatenate([left_joint_effort_array, right_joint_effort_array])
def get_images(self):
"""
Get the latest color image from each camera.
Returns:
dict: camera name -> image array
"""
top_image, _, _, _ = self.camera_top.read_frame(True, False, False, False)
bottom_image, _, _, _ = self.camera_bottom.read_frame(True, False, False, False)
left_image, _, _, _ = self.camera_left.read_frame(True, False, False, False)
right_image, _, _, _ = self.camera_right.read_frame(True, False, False, False)
return {
"cam_high": top_image,
"cam_low": bottom_image,
"cam_left": left_image,
"cam_right": right_image,
}
def get_observation(self):
obs = collections.OrderedDict()
obs["qpos"] = self.get_qps()
obs["qvel"] = self.get_qvel()
obs["effort"] = self.get_effort()
obs["images"] = self.get_images()
# self.get_images()
return obs
def reset(self):
logging.info("Resetting the environment")
_ = self.arm_left.set_joint_position(self.init_left_arm_angle[0:6])
_ = self.arm_right.set_joint_position(self.init_right_arm_angle[0:6])
self.arm_left.set_gripper_position(0)
self.arm_right.set_gripper_position(0)
return dm_env.TimeStep(
step_type=dm_env.StepType.FIRST,
reward=0,
discount=None,
observation=self.get_observation(),
)
def step(self, target_angle):
self.arm_left.set_joint_canfd_position(target_angle[0:6])
self.arm_right.set_joint_canfd_position(target_angle[7:13])
self.arm_left.set_gripper_position(target_angle[6])
self.arm_right.set_gripper_position(target_angle[13])
return dm_env.TimeStep(
step_type=dm_env.StepType.MID,
reward=0,
discount=None,
observation=self.get_observation(),
)
if __name__ == '__main__':
with open("/home/rm/code/shadow_act/config/config.yaml", "r") as f:
config = yaml.safe_load(f)
aloha_config = config["robot_env"]
device = DeviceAloha(aloha_config)
device.reset()
image_list = []
target_angle = np.concatenate([device.init_left_arm_angle, device.init_right_arm_angle])
while True:
tracemalloc.start() # start memory tracing
target_angle = np.array([angle + 0.1 if i not in [6, 13] else angle for i, angle in enumerate(target_angle)])
time_step = time.time()
timestep = device.step(target_angle)
logging.info(f"Time: {time.time() - time_step}")
image_list.append(timestep.observation["images"])
snapshot = tracemalloc.take_snapshot()
top_stats = snapshot.statistics('lineno')
# del timestep
print("[ Top 10 ]")
for stat in top_stats[:10]:
print(stat)
# logging.info(f"Images: {obs}")

View File

@@ -0,0 +1,32 @@
import os
# import time
import yaml
import torch
import pickle
import dm_env
import logging
import collections
import numpy as np
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision import transforms
from shadow_rm_robot.realman_arm import RmArm
from shadow_camera.realsense import RealSenseCamera
from shadow_act.models.latent_model import Latent_Model_Transformer
from shadow_act.network.policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from shadow_act.utils.utils import (
load_data,
sample_box_pose,
sample_insertion_pose,
compute_dict_mean,
set_seed,
detach_dict,
)
logging.basicConfig(level=logging.DEBUG)
print('daasdas')

View File

@@ -0,0 +1,147 @@
import os
import numpy as np
import cv2
import h5py
import argparse
import matplotlib.pyplot as plt
from constants import DT
import IPython
e = IPython.embed
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
STATE_NAMES = JOINT_NAMES + ["gripper"]
def load_hdf5(dataset_dir, dataset_name):
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
if not os.path.isfile(dataset_path):
print(f'Dataset does not exist at \n{dataset_path}\n')
exit()
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()]
action = root['/action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
return qpos, qvel, action, image_dict
def main(args):
dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx']
dataset_name = f'episode_{episode_idx}'
qpos, qvel, action, image_dict = load_hdf5(dataset_dir, dataset_name)
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
# visualize_timestamp(t_list, dataset_path) # TODO add timestamp back
def save_videos(video, dt, video_path=None):
if isinstance(video, list):
cam_names = list(video[0].keys())
h, w, _ = video[0][cam_names[0]].shape
w = w * len(cam_names)
fps = int(1/dt)
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for ts, image_dict in enumerate(video):
images = []
for cam_name in cam_names:
image = image_dict[cam_name]
image = image[:, :, [2, 1, 0]] # swap B and R channel
images.append(image)
images = np.concatenate(images, axis=1)
out.write(images)
out.release()
print(f'Saved video to: {video_path}')
elif isinstance(video, dict):
cam_names = list(video.keys())
all_cam_videos = []
for cam_name in cam_names:
all_cam_videos.append(video[cam_name])
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
n_frames, h, w, _ = all_cam_videos.shape
fps = int(1 / dt)
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames):
image = all_cam_videos[t]
image = image[:, :, [2, 1, 0]] # swap B and R channel
out.write(image)
out.release()
print(f'Saved video to: {video_path}')
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
if label_overwrite:
label1, label2 = label_overwrite
else:
label1, label2 = 'State', 'Command'
qpos = np.array(qpos_list) # ts, dim
command = np.array(command_list)
num_ts, num_dim = qpos.shape
h, w = 2, num_dim
num_figs = num_dim
fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))
# plot joint state
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(qpos[:, dim_idx], label=label1)
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
ax.legend()
# plot arm command
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(command[:, dim_idx], label=label2)
ax.legend()
if ylim:
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.set_ylim(ylim)
plt.tight_layout()
plt.savefig(plot_path)
print(f'Saved qpos plot to: {plot_path}')
plt.close()
def visualize_timestamp(t_list, dataset_path):
plot_path = dataset_path.replace('.pkl', '_timestamp.png')
h, w = 4, 10
fig, axs = plt.subplots(2, 1, figsize=(w, h*2))
# process t_list
t_float = []
for secs, nsecs in t_list:
t_float.append(secs + nsecs * 10E-10)
t_float = np.array(t_float)
ax = axs[0]
ax.plot(np.arange(len(t_float)), t_float)
ax.set_title(f'Camera frame timestamps')
ax.set_xlabel('timestep')
ax.set_ylabel('time (sec)')
ax = axs[1]
ax.plot(np.arange(len(t_float)-1), t_float[:-1] - t_float[1:])
ax.set_title(f'dt')
ax.set_xlabel('timestep')
ax.set_ylabel('time (sec)')
plt.tight_layout()
plt.savefig(plot_path)
print(f'Saved timestamp plot to: {plot_path}')
plt.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
main(vars(parser.parse_args()))

View File

@@ -0,0 +1,10 @@
__pycache__/
build/
devel/
dist/
data/
.catkin_workspace
*.pyc
*.pyo
*.pt
.vscode/

View File

@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

View File

@@ -0,0 +1 @@
aloha_data_synchronizer.py

View File

@@ -0,0 +1,17 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="4">
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
<item index="1" class="java.lang.String" itemvalue="thop" />
<item index="2" class="java.lang.String" itemvalue="torch" />
<item index="3" class="java.lang.String" itemvalue="torchvision" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.11 (随箱软件)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (随箱软件)" project-jdk-type="Python SDK" />
</project>

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" filepath="$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" />
</modules>
</component>
</project>

View File

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@@ -0,0 +1,38 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset'
dataset_name: 'episode'
max_timesteps: 500
state_dim: 14
overwrite: False
arm_axis: 6
camera_names:
- 'cam_high'
- 'cam_low'
- 'cam_left'
- 'cam_right'
ros_topics:
camera_left: '/camera_left/rgb/image_raw'
camera_right: '/camera_right/rgb/image_raw'
camera_bottom: '/camera_bottom/rgb/image_raw'
camera_head: '/camera_head/rgb/image_raw'
left_master_arm: '/left_master_arm_joint_states'
left_slave_arm: '/left_slave_arm_joint_states'
right_master_arm: '/right_master_arm_joint_states'
right_slave_arm: '/right_slave_arm_joint_states'
left_aloha_state: '/left_slave_arm_aloha_state'
right_aloha_state: '/right_slave_arm_aloha_state'
robot_env: {
# TODO change the path to the correct one
rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml',
rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml',
arm_axis: 6,
head_camera: '241122071186',
bottom_camera: '152122078546',
left_camera: '150622070125',
right_camera: '151222072576',
init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0],
init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0]
# init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0],
# init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0]
# init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
# init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}

View File

@@ -0,0 +1,9 @@
arm_ip: "192.168.1.18"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8089
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
get_vel: True
get_torque: True

View File

@@ -0,0 +1,9 @@
arm_ip: "192.168.1.19"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8090
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
get_vel: True
get_torque: True

View File

@@ -0,0 +1,4 @@
port: /dev/ttyUSB1
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6

View File

@@ -0,0 +1,4 @@
port: /dev/ttyUSB0
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6

View File

@@ -0,0 +1,6 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
dataset_name: 'episode'
episode_idx: 1
FPS: 30
# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"] # 7 joints
joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] # 6 joints

View File

@@ -0,0 +1,39 @@
[tool.poetry]
name = "shadow_rm_aloha"
version = "0.1.1"
description = "aloha package, use D435 and Realman robot arm to build aloha to collect data"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
"Operating System :: POSIX :: Linux amd64",
"Programming Language :: Python :: 3.10",
]
[tool.poetry.dependencies]
python = ">=3.10"
matplotlib = ">=3.9.2"
h5py = ">=3.12.1"
# rospy = ">=1.17.0"
# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" }
# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" }
[tool.poetry.dev-dependencies] # Development-time dependencies, e.g. testing and documentation tools.
pytest = ">=8.3"
black = ">=24.10.0"
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
[tool.poetry.group.dev.dependencies]
[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,42 @@
cmake_minimum_required(VERSION 3.0.2)
project(shadow_rm_aloha)
find_package(catkin REQUIRED COMPONENTS
rospy
sensor_msgs
cv_bridge
image_transport
std_msgs
message_generation
)
add_service_files(
FILES
GetArmStatus.srv
GetImage.srv
MoveArm.srv
)
generate_messages(
DEPENDENCIES
sensor_msgs
std_msgs
)
catkin_package(
CATKIN_DEPENDS message_runtime rospy std_msgs
)
include_directories(
${catkin_INCLUDE_DIRS}
)
install(PROGRAMS
arm_node/slave_arm_publisher.py
arm_node/master_arm_publisher.py
arm_node/slave_arm_pub_sub.py
camera_node/camera_publisher.py
data_sub_process/aloha_data_synchronizer.py
data_sub_process/aloha_data_collect.py
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import rospy
import logging
from shadow_rm_robot.servo_robotic_arm import ServoArm
from sensor_msgs.msg import JointState
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class MasterArmPublisher:
def __init__(self):
rospy.init_node("master_arm_publisher", anonymous=True)
arm_config = rospy.get_param("~arm_config","config/servo_left_arm.yaml")
hz = rospy.get_param("~hz", 250)
self.joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
self.arm = ServoArm(arm_config)
self.publisher = rospy.Publisher(self.joint_states_topic, JointState, queue_size=1)
self.rate = rospy.Rate(hz) # publish rate in Hz (default 250)
def publish_joint_states(self):
while not rospy.is_shutdown():
joint_state = JointState()
joint_pos = self.arm.get_joint_actions()
joint_state.header.stamp = rospy.Time.now()
joint_state.name = list(joint_pos.keys())
joint_state.position = list(joint_pos.values())
joint_state.velocity = [0.0] * len(joint_pos) # velocity (set as needed)
joint_state.effort = [0.0] * len(joint_pos) # effort/torque (set as needed)
# rospy.loginfo(f"{self.joint_states_topic}: {joint_state}")
self.publisher.publish(joint_state)
self.rate.sleep()
if __name__ == "__main__":
try:
arm_publisher = MasterArmPublisher()
arm_publisher.publish_joint_states()
except rospy.ROSInterruptException:
pass

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
import rospy
import logging
from shadow_rm_robot.realman_arm import RmArm
from sensor_msgs.msg import JointState
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class SlaveArmPublisher:
def __init__(self):
rospy.init_node("slave_arm_publisher", anonymous=True)
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
hz = rospy.get_param("~hz", 250)
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
joint_actions_topic = rospy.get_param("~joint_actions_topic", "/joint_actions")
self.arm = RmArm(arm_config)
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
self.subscriber = rospy.Subscriber(joint_actions_topic, JointState, self.callback)
self.rate = rospy.Rate(hz)
def publish_joint_states(self):
while not rospy.is_shutdown():
joint_state = JointState()
data = self.arm.get_integrate_data()
# data = self.arm.get_arm_data()
joint_state.header.stamp = rospy.Time.now()
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
# joint_state.position = data["joint_angle"]
joint_state.position = data['arm_angle']
# joint_state.position = list(data["arm_angle"])
# joint_state.velocity = list(data["arm_velocity"])
# joint_state.effort = list(data["arm_torque"])
# rospy.loginfo(f"joint_states_topic: {joint_state}")
self.publisher.publish(joint_state)
self.rate.sleep()
def callback(self, data):
# rospy.loginfo(f"Received joint_states_topic: {data}")
start_time = rospy.Time.now()
if data is None:
return
if data.name == ["joint_canfd"]:
self.arm.set_joint_canfd_position(data.position[0:6])
elif data.name == ["joint_j"]:
self.arm.set_joint_position(data.position[0:6])
# self.arm.set_gripper_position(data.position[6])
end_time = rospy.Time.now()
time_cost_ms = (end_time - start_time).to_sec() * 1000
rospy.loginfo(f"Time cost: {data.name},{time_cost_ms}")
if __name__ == "__main__":
try:
arm_publisher = SlaveArmPublisher()
arm_publisher.publish_joint_states()
except rospy.ROSInterruptException:
pass

View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import rospy
import logging
from shadow_rm_robot.realman_arm import RmArm
from sensor_msgs.msg import JointState
from std_msgs.msg import Int32MultiArray
# Configure logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class SlaveArmPublisher:
def __init__(self):
rospy.init_node("slave_arm_publisher", anonymous=True)
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
hz = rospy.get_param("~hz", 250)
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
aloha_state_topic = rospy.get_param("~aloha_state_topic", "/aloha_state")
self.arm = RmArm(arm_config)
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
self.aloha_state_pub = rospy.Publisher(aloha_state_topic, Int32MultiArray, queue_size=1)
self.rate = rospy.Rate(hz)
def publish_joint_states(self):
while not rospy.is_shutdown():
joint_state = JointState()
data = self.arm.get_integrate_data()
joint_state.header.stamp = rospy.Time.now()
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"]
joint_state.position = list(data["arm_angle"])
joint_state.velocity = list(data["arm_velocity"])
joint_state.effort = list(data["arm_torque"])
# rospy.loginfo(f"joint_states_topic: {joint_state}")
self.aloha_state_pub.publish(Int32MultiArray(data=list(data["aloha_state"].values())))
self.publisher.publish(joint_state)
self.rate.sleep()
if __name__ == "__main__":
try:
arm_publisher = SlaveArmPublisher()
arm_publisher.publish_joint_states()
except rospy.ROSInterruptException:
pass

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
import rospy
from sensor_msgs.msg import Image
from std_msgs.msg import Header
import numpy as np
from shadow_camera.realsense import RealSenseCamera
class CameraPublisher:
def __init__(self):
rospy.init_node('camera_publisher', anonymous=True)
self.serial_number = rospy.get_param('~serial_number', None)
hz = rospy.get_param('~hz', 30)
rospy.loginfo(f"Serial number: {self.serial_number}")
self.rgb_topic = rospy.get_param('~rgb_topic', '/camera/rgb/image_raw')
self.depth_topic = rospy.get_param('~depth_topic', '/camera/depth/image_raw')
rospy.loginfo(f"RGB topic: {self.rgb_topic}")
rospy.loginfo(f"Depth topic: {self.depth_topic}")
self.rgb_pub = rospy.Publisher(self.rgb_topic, Image, queue_size=10)
# self.depth_pub = rospy.Publisher(self.depth_topic, Image, queue_size=10)
self.rate = rospy.Rate(hz) # 30 Hz
self.camera = RealSenseCamera(self.serial_number, False)
rospy.loginfo("Camera initialized")
def publish_images(self):
self.camera.start_camera()
rospy.loginfo("Camera started")
while not rospy.is_shutdown():
result = self.camera.read_frame(True, False, False, False)
if result is None:
rospy.logerr("Failed to read frame from camera")
continue
color_image, depth_image, _, _ = result
if color_image is not None or depth_image is not None:
rgb_msg = self.create_image_msg(color_image, "bgr8")
# depth_msg = self.create_image_msg(depth_image, "mono16")
self.rgb_pub.publish(rgb_msg)
# self.depth_pub.publish(depth_msg)
# rospy.loginfo("Published RGB image")
else:
rospy.logwarn("Received None for color_image or depth_image")
self.rate.sleep()
self.camera.stop_camera()
rospy.loginfo("Camera stopped")
def create_image_msg(self, image, encoding):
msg = Image()
msg.header = Header()
msg.header.stamp = rospy.Time.now()
msg.height, msg.width = image.shape[:2]
msg.encoding = encoding
msg.is_bigendian = False
msg.step = image.strides[0]
msg.data = np.array(image).tobytes()
return msg
if __name__ == '__main__':
try:
camera_publisher = CameraPublisher()
camera_publisher.publish_images()
except rospy.ROSInterruptException:
pass

View File

@@ -0,0 +1,112 @@
import os
import time
import h5py
import logging
from datetime import datetime
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
class DataCollector:
def __init__(self, dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite=False):
self.arm_axis = 7
self.dataset_dir = dataset_dir
self.dataset_name = dataset_name
self.max_timesteps = max_timesteps
self.camera_names = camera_names
self.state_dim = state_dim
self.overwrite = overwrite
self.data_dict = {
'/observations/qpos': [],
'/observations/qvel': [],
'/observations/effort': [],
'/action': [],
}
for cam_name in camera_names:
self.data_dict[f'/observations/images/{cam_name}'] = []
# Automatically create the dataset directory if needed
self.create_dataset_dir()
self.timesteps_collected = 0
def create_dataset_dir(self):
# Create a subdirectory named by date (YYYYMMDD)
date_str = datetime.now().strftime("%Y%m%d")
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)
def collect_data(self, ts, action):
self.data_dict['/observations/qpos'].append(ts.observation['qpos'])
self.data_dict['/observations/qvel'].append(ts.observation['qvel'])
self.data_dict['/observations/effort'].append(ts.observation['effort'])
self.data_dict['/action'].append(action)
for cam_name in self.camera_names:
self.data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
def save_data(self):
t0 = time.time()
# Write the episode to HDF5
with h5py.File(self.dataset_path, mode='w', rdcc_nbytes=1024**2*2) as root:
root.attrs['sim'] = False
obs = root.create_group('observations')
image = obs.create_group('images')
for cam_name in self.camera_names:
_ = image.create_dataset(cam_name, (self.max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3))
_ = obs.create_dataset('qpos', (self.max_timesteps, self.state_dim))
_ = obs.create_dataset('qvel', (self.max_timesteps, self.state_dim))
_ = obs.create_dataset('effort', (self.max_timesteps, self.state_dim))
_ = root.create_dataset('action', (self.max_timesteps, self.state_dim))
for name, array in self.data_dict.items():
root[name][...] = array
print(f'Saving: {time.time() - t0:.1f} secs')
return True
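# Read-back sketch for the layout written above (the episode path is a placeholder):
#   with h5py.File('/path/to/episode_0.hdf5', 'r') as root:
#       qpos = root['/observations/qpos'][()]      # (max_timesteps, state_dim)
#       action = root['/action'][()]               # (max_timesteps, state_dim)
#       cam_high = root['/observations/images/cam_high'][()]  # (T, 480, 640, 3) uint8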
def load_hdf5(self, orign_path, file):
self.dataset_path = os.path.join(self.dataset_dir, file)
if not os.path.isfile(orign_path):
logging.error(f'Dataset does not exist at {orign_path}')
exit()
with h5py.File(orign_path, 'r') as root:
self.is_sim = root.attrs['sim']
self.qpos = root['/observations/qpos'][()]
self.qvel = root['/observations/qvel'][()]
self.effort = root['/observations/effort'][()]
self.action = root['/action'][()]
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
for cam_name in root[f'/observations/images/'].keys()}
self.qpos[:, self.arm_axis] = self.action[:, self.arm_axis]
self.qpos[:, self.arm_axis*2+1] = self.action[:, self.arm_axis*2+1]
self.data_dict['/observations/qpos'] = self.qpos
self.data_dict['/observations/qvel'] = self.qvel
self.data_dict['/observations/effort'] = self.effort
self.data_dict['/action'] = self.action
for cam_name in self.camera_names:
self.data_dict[f'/observations/images/{cam_name}'] = self.image_dict[cam_name]
return True
if __name__ == '__main__':
"""
Rewrite the gripper channels: replace the follower-arm gripper values with the leader-arm gripper values.
"""
dataset_dir = '/home/wang/project/shadow_rm_aloha/data'
orign_dir = '/home/wang/project/shadow_rm_aloha/data/dataset/20241128'
dataset_name = 'test'
max_timesteps = 300
camera_names = ['cam_high','cam_low','cam_left','cam_right']
state_dim = 16
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
for file in os.listdir(orign_dir):
collector.__init__(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
orign_path = os.path.join(orign_dir, file)
print(orign_path)
collector.load_hdf5(orign_path, file)
collector.save_data()

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
import os
import numpy as np
import h5py
import yaml
import logging
import time
from shadow_rm_robot.realman_arm import RmArm
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
class DataValidator:
def __init__(self, config):
self.dataset_dir = config['dataset_dir']
self.episode_idx = config['episode_idx']
self.joint_names = config['joint_names']
self.dataset_name = f'episode_{self.episode_idx}'
self.dataset_path = os.path.join(self.dataset_dir, self.dataset_name + '.hdf5')
self.state_names = self.joint_names + ["gripper"]
self.arm = RmArm('/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml')
def load_hdf5(self):
if not os.path.isfile(self.dataset_path):
logging.error(f'Dataset does not exist at {self.dataset_path}')
exit()
with h5py.File(self.dataset_path, 'r') as root:
self.is_sim = root.attrs['sim']
self.qpos = root['/observations/qpos'][()]
# self.qvel = root['/observations/qvel'][()]
# self.effort = root['/observations/effort'][()]
self.action = root['/action'][()]
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
for cam_name in root[f'/observations/images/'].keys()}
def validate_data(self):
# Validate the joint position data
if not self.qpos.shape[1] == 14:
logging.error('qpos shape does not match expected number of joints')
return False
logging.info('Data validation passed')
return True
def control_robot(self):
self.arm.set_joint_position(self.qpos[0][0:6])
for qpos in self.qpos:
logging.info(f'qpos: {qpos}')
self.arm.set_joint_canfd_position(qpos[7:13])
self.arm.set_gripper_position(qpos[13])
time.sleep(0.035)
def run(self):
self.load_hdf5()
if self.validate_data():
self.control_robot()
def load_config(config_path):
with open(config_path, 'r') as file:
return yaml.safe_load(file)
if __name__ == '__main__':
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
validator = DataValidator(config)
validator.run()

View File

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
import os
import numpy as np
import cv2
import h5py
import yaml
import logging
import matplotlib.pyplot as plt
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
class DataVisualizer:
def __init__(self, config):
self.dataset_dir = config['dataset_dir']
self.episode_idx = config['episode_idx']
self.dt = 1/config['FPS']
self.joint_names = config['joint_names']
self.state_names = self.joint_names + ["gripper"]
# self.camera_names = config['camera_names']
def join_file_path(self, file_name):
self.dataset_path = os.path.join(self.dataset_dir, file_name)
def load_hdf5(self):
if not os.path.isfile(self.dataset_path):
logging.error(f'Dataset does not exist at {self.dataset_path}')
exit()
with h5py.File(self.dataset_path, 'r') as root:
self.is_sim = root.attrs['sim']
self.qpos = root['/observations/qpos'][()]
# self.qvel = root['/observations/qvel'][()]
# self.effort = root['/observations/effort'][()]
self.action = root['/action'][()]
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
for cam_name in root[f'/observations/images/'].keys()}
def save_videos(self, video, dt, video_path=None):
if isinstance(video, list):
cam_names = list(video[0].keys())
h, w, _ = video[0][cam_names[0]].shape
w = w * len(cam_names)
fps = int(1 / dt)
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for image_dict in video:
images = [image_dict[cam_name][:, :, [2, 1, 0]] for cam_name in cam_names]
out.write(np.concatenate(images, axis=1))
out.release()
logging.info(f'Saved video to: {video_path}')
elif isinstance(video, dict):
cam_names = list(video.keys())
all_cam_videos = np.concatenate([video[cam_name] for cam_name in cam_names], axis=2)
n_frames, h, w, _ = all_cam_videos.shape
fps = int(1 / dt)
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames):
out.write(all_cam_videos[t][:, :, [2, 1, 0]])
out.release()
logging.info(f'Saved video to: {video_path}')
def visualize_joints(self, qpos_list, command_list, plot_path, ylim=None, label_overwrite=None):
label1, label2 = label_overwrite if label_overwrite else ('State', 'Command')
qpos = np.array(qpos_list)
command = np.array(command_list)
num_ts, num_dim = qpos.shape
logging.info(f'qpos shape: {qpos.shape}, command shape: {command.shape}')
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(qpos[:, dim_idx], label=label1)
ax.plot(command[:, dim_idx], label=label2)
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
ax.legend()
if ylim:
ax.set_ylim(ylim)
plt.tight_layout()
plt.savefig(plot_path)
logging.info(f'Saved qpos plot to: {plot_path}')
plt.close()
def visualize_single(self, data_list, label, plot_path, ylim=None):
data = np.array(data_list)
num_ts, num_dim = data.shape
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(data[:, dim_idx], label=label)
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
ax.legend()
if ylim:
ax.set_ylim(ylim)
plt.tight_layout()
plt.savefig(plot_path)
logging.info(f'Saved {label} plot to: {plot_path}')
plt.close()
def visualize_timestamp(self, t_list):
plot_path = self.dataset_path.replace('.hdf5', '_timestamp.png')
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
t_float = np.array([secs + nsecs * 1e-9 for secs, nsecs in t_list])
axs[0].plot(np.arange(len(t_float)), t_float)
axs[0].set_title('Camera frame timestamps')
axs[0].set_xlabel('timestep')
axs[0].set_ylabel('time (sec)')
axs[1].plot(np.arange(len(t_float) - 1), t_float[:-1] - t_float[1:])
axs[1].set_title('dt')
axs[1].set_xlabel('timestep')
axs[1].set_ylabel('time (sec)')
plt.tight_layout()
plt.savefig(plot_path)
logging.info(f'Saved timestamp plot to: {plot_path}')
plt.close()
def run(self):
for file_name in os.listdir(self.dataset_dir):
if file_name.endswith('.hdf5'):
self.join_file_path(file_name)
self.load_hdf5()
video_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_video.mp4'))
self.save_videos(self.image_dict, self.dt, video_path)
qpos_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_qpos.png'))
self.visualize_joints(self.qpos, self.action, qpos_plot_path)
# effort_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_effort.png'))
# self.visualize_single(self.effort, 'effort', effort_plot_path)
# error_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_error.png'))
# self.visualize_single(self.action - self.qpos, 'tracking_error', error_plot_path)
# self.visualize_timestamp(t_list) # TODO: Add timestamp visualization back
def load_config(config_path):
with open(config_path, 'r') as file:
return yaml.safe_load(file)
if __name__ == '__main__':
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
visualizer = DataVisualizer(config)
visualizer.run()

View File

@@ -0,0 +1,180 @@
#!/usr/bin/env python3
import time
import yaml
import rospy
import dm_env
import numpy as np
import collections
from datetime import datetime
from sensor_msgs.msg import Image, JointState
from shadow_rm_robot.realman_arm import RmArm
from message_filters import Subscriber, ApproximateTimeSynchronizer
class DataSynchronizer:
def __init__(self, config_path="config"):
rospy.init_node("synchronizer", anonymous=True)
with open(config_path, "r") as file:
config = yaml.safe_load(file)
self.init_left_arm_angle = config["robot_env"]["init_left_arm_angle"]
self.init_right_arm_angle = config["robot_env"]["init_right_arm_angle"]
self.arm_axis = config["robot_env"]["arm_axis"]
self.camera_names = config["camera_names"]
# Create subscribers
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
self.camera_bottom_sub = Subscriber(
config["ros_topics"]["camera_bottom"], Image
)
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
self.left_slave_arm_sub = Subscriber(
config["ros_topics"]["left_slave_arm_sub"], JointState
)
self.right_slave_arm_sub = Subscriber(
config["ros_topics"]["right_slave_arm_sub"], JointState
)
self.left_slave_arm_pub = rospy.Publisher(
config["ros_topics"]["left_slave_arm_pub"], JointState, queue_size=1
)
self.right_slave_arm_pub = rospy.Publisher(
config["ros_topics"]["right_slave_arm_pub"], JointState, queue_size=1
)
# Create the approximate-time synchronizer
self.ats = ApproximateTimeSynchronizer(
[
self.camera_left_sub,
self.camera_right_sub,
self.camera_bottom_sub,
self.camera_head_sub,
self.left_slave_arm_sub,
self.right_slave_arm_sub,
],
queue_size=1,
slop=0.1,
)
self.ats.registerCallback(self.callback)
self.ts = None
self.is_first_step = True
def callback(
self,
camera_left_img,
camera_right_img,
camera_bottom_img,
camera_head_img,
left_slave_arm,
right_slave_arm,
):
# Convert the ROS image messages to NumPy arrays
camera_left_np_img = np.frombuffer(
camera_left_img.data, dtype=np.uint8
).reshape(camera_left_img.height, camera_left_img.width, -1)
camera_right_np_img = np.frombuffer(
camera_right_img.data, dtype=np.uint8
).reshape(camera_right_img.height, camera_right_img.width, -1)
camera_bottom_np_img = np.frombuffer(
camera_bottom_img.data, dtype=np.uint8
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
camera_head_np_img = np.frombuffer(
camera_head_img.data, dtype=np.uint8
).reshape(camera_head_img.height, camera_head_img.width, -1)
left_slave_arm_angle = left_slave_arm.position
left_slave_arm_velocity = left_slave_arm.velocity
left_slave_arm_force = left_slave_arm.effort
# The Inspire (因时) gripper angle matches the master arm angle; comment this out if not using an Inspire gripper
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
right_slave_arm_angle = right_slave_arm.position
right_slave_arm_velocity = right_slave_arm.velocity
right_slave_arm_force = right_slave_arm.effort
# The Inspire (因时) gripper angle matches the master arm angle; comment this out if not using an Inspire gripper
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
# Collect the observation data
obs = collections.OrderedDict(
{
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
"qvel": np.concatenate(
[left_slave_arm_velocity, right_slave_arm_velocity]
),
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
"images": {
self.camera_names[0]: camera_head_np_img,
self.camera_names[1]: camera_bottom_np_img,
self.camera_names[2]: camera_left_np_img,
self.camera_names[3]: camera_right_np_img,
},
}
)
self.ts = dm_env.TimeStep(
step_type=(
dm_env.StepType.FIRST if self.is_first_step else dm_env.StepType.MID
),
reward=0.0,
discount=1.0,
observation=obs,
)
def reset(self):
left_joint_state = JointState()
left_joint_state.header.stamp = rospy.Time.now()
left_joint_state.name = ["joint_j"]
left_joint_state.position = self.init_left_arm_angle[0 : self.arm_axis + 1]
right_joint_state = JointState()
right_joint_state.header.stamp = rospy.Time.now()
right_joint_state.name = ["joint_j"]
right_joint_state.position = self.init_right_arm_angle[0 : self.arm_axis + 1]
self.left_slave_arm_pub.publish(left_joint_state)
self.right_slave_arm_pub.publish(right_joint_state)
while self.ts is None:
time.sleep(0.002)
return self.ts
def step(self, target_angle):
self.is_first_step = False
left_joint_state = JointState()
left_joint_state.header.stamp = rospy.Time.now()
left_joint_state.name = ["joint_canfd"]
left_joint_state.position = target_angle[0 : self.arm_axis + 1]
# print("left_joint_state: ", left_joint_state)
right_joint_state = JointState()
right_joint_state.header.stamp = rospy.Time.now()
right_joint_state.name = ["joint_canfd"]
right_joint_state.position = target_angle[self.arm_axis + 1 : (self.arm_axis + 1) * 2]
# print("right_joint_state: ", right_joint_state)
self.left_slave_arm_pub.publish(left_joint_state)
self.right_slave_arm_pub.publish(right_joint_state)
# time.sleep(0.013)
return self.ts
def run(self):
rospy.loginfo("Starting ROS spin")
data = np.concatenate([self.init_left_arm_angle, self.init_right_arm_angle])
self.reset()
# print("data: ", data)
while not rospy.is_shutdown():
self.step(data)
rospy.sleep(0.010)
if __name__ == "__main__":
synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
start_time = time.time()
synchronizer.run()
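For reference, a hedged sketch of closing the loop around reset()/step() from another script; the hold-pose "policy" and the roughly 100 Hz pacing are placeholders, and the joint ordering (left arm plus gripper followed by right arm plus gripper) follows the slicing in step() above.

# Hypothetical control loop; import path and config path must match your setup.
import time
import numpy as np

sync = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
ts = sync.reset()                                  # blocks until the first synchronized observation arrives
for _ in range(1000):
    qpos = np.asarray(ts.observation["qpos"])      # concatenated [left, right] joint positions
    target = qpos.copy()                           # placeholder policy: hold the current pose
    ts = sync.step(target.tolist())                # expects (arm_axis + 1) * 2 values
    time.sleep(0.01)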

View File

@@ -0,0 +1,280 @@
#!/usr/bin/env python3
import os
import time
import h5py
import yaml
import rospy
import dm_env
import numpy as np
import collections
from datetime import datetime
from std_msgs.msg import Int32MultiArray
from sensor_msgs.msg import Image, JointState
from message_filters import Subscriber, ApproximateTimeSynchronizer
class DataCollector:
def __init__(
self,
dataset_dir,
dataset_name,
max_timesteps,
camera_names,
state_dim,
overwrite=False,
):
self.dataset_dir = dataset_dir
self.dataset_name = dataset_name
self.max_timesteps = max_timesteps
self.camera_names = camera_names
self.state_dim = state_dim
self.overwrite = overwrite
self.init_dict()
self.create_dataset_dir()
def init_dict(self):
self.data_dict = {
"/observations/qpos": [],
"/observations/qvel": [],
"/observations/effort": [],
"/action": [],
}
for cam_name in self.camera_names:
self.data_dict[f"/observations/images/{cam_name}"] = []
def create_dataset_dir(self):
# Create a sub-directory named by date (YYYYMMDD)
date_str = datetime.now().strftime("%Y%m%d")
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)
def create_file(self):
# If a dataset with this name already exists, increment the suffix until it is unique
counter = 0
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
if not self.overwrite:
while os.path.exists(dataset_path + ".hdf5"):
counter += 1
dataset_path = os.path.join(
self.dataset_dir, f"{self.dataset_name}_{counter}"
)
self.dataset_path = dataset_path
def collect_data(self, ts, action):
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
self.data_dict["/observations/effort"].append(ts.observation["effort"])
self.data_dict["/action"].append(action)
for cam_name in self.camera_names:
self.data_dict[f"/observations/images/{cam_name}"].append(
ts.observation["images"][cam_name]
)
def save_data(self):
self.create_file()
t0 = time.time()
# Save the data to HDF5
with h5py.File(
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
) as root:
root.attrs["sim"] = False
obs = root.create_group("observations")
image = obs.create_group("images")
for cam_name in self.camera_names:
_ = image.create_dataset(
cam_name,
(self.max_timesteps, 480, 640, 3),
dtype="uint8",
chunks=(1, 480, 640, 3),
)
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
for name, array in self.data_dict.items():
root[name][...] = array
print(f"Saving: {time.time() - t0:.1f} secs")
return True
class DataSynchronizer:
def __init__(self, config_path="config"):
rospy.init_node("synchronizer", anonymous=True)
rospy.loginfo("ROS node initialized")
with open(config_path, "r") as file:
config = yaml.safe_load(file)
self.arm_axis = config["arm_axis"]
# Create subscribers
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
self.camera_bottom_sub = Subscriber(
config["ros_topics"]["camera_bottom"], Image
)
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
self.left_master_arm_sub = Subscriber(
config["ros_topics"]["left_master_arm"], JointState
)
self.left_slave_arm_sub = Subscriber(
config["ros_topics"]["left_slave_arm"], JointState
)
self.right_master_arm_sub = Subscriber(
config["ros_topics"]["right_master_arm"], JointState
)
self.right_slave_arm_sub = Subscriber(
config["ros_topics"]["right_slave_arm"], JointState
)
self.left_aloha_state_sub = rospy.Subscriber(
config["ros_topics"]["left_aloha_state"],
Int32MultiArray,
self.aloha_state_callback,
)
rospy.loginfo("Subscribers created")
self.camera_names = config["camera_names"]
# Create the approximate-time synchronizer
self.ats = ApproximateTimeSynchronizer(
[
self.camera_left_sub,
self.camera_right_sub,
self.camera_bottom_sub,
self.camera_head_sub,
self.left_master_arm_sub,
self.left_slave_arm_sub,
self.right_master_arm_sub,
self.right_slave_arm_sub,
],
queue_size=1,
slop=0.05,
)
self.ats.registerCallback(self.callback)
rospy.loginfo("Time synchronizer created and callback registered")
self.data_collector = DataCollector(
dataset_dir=config["dataset_dir"],
dataset_name=config["dataset_name"],
max_timesteps=config["max_timesteps"],
camera_names=config["camera_names"],
state_dim=config["state_dim"],
overwrite=config["overwrite"],
)
self.timesteps_collected = 0
self.begin_collect = False
self.last_time = None
def callback(
self,
camera_left_img,
camera_right_img,
camera_bottom_img,
camera_head_img,
left_master_arm,
left_slave_arm,
right_master_arm,
right_slave_arm,
):
if self.begin_collect:
self.timesteps_collected += 1
rospy.loginfo(
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
)
else:
self.timesteps_collected = 0
return
if self.timesteps_collected == 0:
return
current_time = time.time()
if self.last_time is not None:
frequency = 1.0 / (current_time - self.last_time)
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
self.last_time = current_time
# Convert the ROS image messages to NumPy arrays
camera_left_np_img = np.frombuffer(
camera_left_img.data, dtype=np.uint8
).reshape(camera_left_img.height, camera_left_img.width, -1)
camera_right_np_img = np.frombuffer(
camera_right_img.data, dtype=np.uint8
).reshape(camera_right_img.height, camera_right_img.width, -1)
camera_bottom_np_img = np.frombuffer(
camera_bottom_img.data, dtype=np.uint8
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
camera_head_np_img = np.frombuffer(
camera_head_img.data, dtype=np.uint8
).reshape(camera_head_img.height, camera_head_img.width, -1)
# Extract the arm joint angles, velocities, and efforts
left_master_arm_angle = left_master_arm.position
# left_master_arm_velocity = left_master_arm.velocity
# left_master_arm_force = left_master_arm.effort
left_slave_arm_angle = left_slave_arm.position
left_slave_arm_velocity = left_slave_arm.velocity
left_slave_arm_force = left_slave_arm.effort
# The Inspire (因时) gripper angle matches the master arm angle; comment this out if not using an Inspire gripper
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
right_master_arm_angle = right_master_arm.position
# right_master_arm_velocity = right_master_arm.velocity
# right_master_arm_force = right_master_arm.effort
right_slave_arm_angle = right_slave_arm.position
right_slave_arm_velocity = right_slave_arm.velocity
right_slave_arm_force = right_slave_arm.effort
# The Inspire (因时) gripper angle matches the master arm angle; comment this out if not using an Inspire gripper
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
# Collect the observation data
obs = collections.OrderedDict(
{
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
"qvel": np.concatenate(
[left_slave_arm_velocity, right_slave_arm_velocity]
),
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
"images": {
self.camera_names[0]: camera_head_np_img,
self.camera_names[1]: camera_bottom_np_img,
self.camera_names[2]: camera_left_np_img,
self.camera_names[3]: camera_right_np_img,
},
}
)
# print(self.camera_names[0])  # debug
ts = dm_env.TimeStep(
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
)
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
self.data_collector.collect_data(ts, action)
# Check whether enough timesteps have been collected
if self.timesteps_collected >= self.data_collector.max_timesteps:
self.data_collector.save_data()
rospy.loginfo("Data collection complete")
self.data_collector.init_dict()
self.begin_collect = False
self.timesteps_collected = 0
def aloha_state_callback(self, data):
if not self.begin_collect:
self.aloha_state = data.data
print(self.aloha_state[0], self.aloha_state[1])
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
self.begin_collect = True
def run(self):
rospy.loginfo("Starting ROS spin")
rospy.spin()
if __name__ == "__main__":
synchronizer = DataSynchronizer(
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
)
synchronizer.run()
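As a quick sanity check on a recorded episode, a small read-back sketch is shown below; the file path is a placeholder, while the dataset keys and shapes come from save_data above.

# Hypothetical read-back of one episode; point the path at a real .hdf5 file.
import h5py

with h5py.File("/path/to/episode_0.hdf5", "r") as root:
    qpos = root["/observations/qpos"][...]          # (max_timesteps, state_dim)
    action = root["/action"][...]                   # (max_timesteps, state_dim)
    images = root["/observations/images"]
    for cam_name in images:
        print(cam_name, images[cam_name].shape)     # (max_timesteps, 480, 640, 3)
    print("qpos:", qpos.shape, "action:", action.shape, "sim:", root.attrs["sim"])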

View File

@@ -0,0 +1,63 @@
<launch>
<!-- Left slave arm node -->
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
<param name="aloha_state_topic" value="/left_slave_arm_aloha_state" type= "string"/>
<param name="hz" value="120" type= "int"/>
</node>
<!-- Right slave arm node -->
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
<param name="aloha_state_topic" value="/right_slave_arm_aloha_state" type= "string"/>
<param name="hz" value="120" type= "int"/>
</node>
<!-- Left master arm node -->
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- Right master arm node -->
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- Right arm camera node -->
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="151222072576" type= "string"/>
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Left arm camera node -->
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="150622070125" type= "string"/>
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Head (top) camera node -->
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="241122071186" type= "string"/>
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Bottom camera node -->
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="152122078546" type= "string"/>
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
</launch>

View File

@@ -0,0 +1,49 @@
<launch>
<!-- Left slave arm node -->
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
<param name="joint_actions_topic" value="/left_slave_arm_joint_actions" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- Right slave arm node -->
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
<param name="joint_actions_topic" value="/right_slave_arm_joint_actions" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- Right arm camera node -->
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="151222072576" type= "string"/>
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Left arm camera node -->
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="150622070125" type= "string"/>
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Head (top) camera node -->
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="241122071186" type= "string"/>
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Bottom camera node -->
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="152122078546" type= "string"/>
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
</launch>

View File

@@ -0,0 +1,61 @@
<launch>
<!-- Left slave arm node -->
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Right slave arm node -->
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Left master arm node -->
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Right master arm node -->
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Right arm camera node -->
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="216322070299" type= "string"/>
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Left arm camera node -->
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="216322074992" type= "string"/>
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Head (top) camera node -->
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="215322076086" type= "string"/>
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- Bottom camera node -->
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="215222074360" type= "string"/>
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
</launch>

View File

@@ -0,0 +1,284 @@
#!/usr/bin/env python3
import os
import cv2
import time
import h5py
import yaml
import json
import rospy
import dm_env
import socket
import numpy as np
import collections
from datetime import datetime
from std_msgs.msg import Int32MultiArray
from sensor_msgs.msg import Image, JointState
from message_filters import Subscriber, ApproximateTimeSynchronizer
class DataCollector:
def __init__(
self,
dataset_dir,
dataset_name,
max_timesteps,
camera_names,
state_dim,
overwrite=False,
):
self.dataset_dir = dataset_dir
self.dataset_name = dataset_name
self.max_timesteps = max_timesteps
self.camera_names = camera_names
self.state_dim = state_dim
self.overwrite = overwrite
self.init_dict()
self.create_dataset_dir()
def init_dict(self):
self.data_dict = {
"/observations/qpos": [],
"/observations/qvel": [],
"/observations/effort": [],
"/action": [],
}
for cam_name in self.camera_names:
self.data_dict[f"/observations/images/{cam_name}"] = []
def create_dataset_dir(self):
# Create a sub-directory named by date (YYYYMMDD)
date_str = datetime.now().strftime("%Y%m%d")
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)
def create_file(self):
# If a dataset with this name already exists, increment the suffix until it is unique
counter = 0
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
if not self.overwrite:
while os.path.exists(dataset_path + ".hdf5"):
counter += 1
dataset_path = os.path.join(
self.dataset_dir, f"{self.dataset_name}_{counter}"
)
self.dataset_path = dataset_path
def collect_data(self, ts, action):
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
self.data_dict["/observations/effort"].append(ts.observation["effort"])
self.data_dict["/action"].append(action)
for cam_name in self.camera_names:
self.data_dict[f"/observations/images/{cam_name}"].append(
ts.observation["images"][cam_name]
)
def save_data(self):
self.create_file()
t0 = time.time()
# Save the data to HDF5
with h5py.File(
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
) as root:
root.attrs["sim"] = False
obs = root.create_group("observations")
image = obs.create_group("images")
for cam_name in self.camera_names:
_ = image.create_dataset(
cam_name,
(self.max_timesteps, 480, 640, 3),
dtype="uint8",
chunks=(1, 480, 640, 3),
)
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
for name, array in self.data_dict.items():
root[name][...] = array
print(f"Saving: {time.time() - t0:.1f} secs")
return True
class DataSynchronizer:
def __init__(self, config_path="config"):
rospy.init_node("synchronizer", anonymous=True)
rospy.loginfo("ROS node initialized")
with open(config_path, "r") as file:
config = yaml.safe_load(file)
self.arm_axis = config["arm_axis"]
# Create subscribers
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
self.camera_bottom_sub = Subscriber(
config["ros_topics"]["camera_bottom"], Image
)
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
self.left_master_arm_sub = Subscriber(
config["ros_topics"]["left_master_arm"], JointState
)
self.left_slave_arm_sub = Subscriber(
config["ros_topics"]["left_slave_arm"], JointState
)
self.right_master_arm_sub = Subscriber(
config["ros_topics"]["right_master_arm"], JointState
)
self.right_slave_arm_sub = Subscriber(
config["ros_topics"]["right_slave_arm"], JointState
)
self.left_aloha_state_sub = rospy.Subscriber(
config["ros_topics"]["left_aloha_state"],
Int32MultiArray,
self.aloha_state_callback,
)
rospy.loginfo("Subscribers created")
# Create the approximate-time synchronizer
self.ats = ApproximateTimeSynchronizer(
[
self.camera_left_sub,
self.camera_right_sub,
self.camera_bottom_sub,
self.camera_head_sub,
self.left_master_arm_sub,
self.left_slave_arm_sub,
self.right_master_arm_sub,
self.right_slave_arm_sub,
],
queue_size=1,
slop=0.05,
)
self.ats.registerCallback(self.callback)
rospy.loginfo("Time synchronizer created and callback registered")
self.data_collector = DataCollector(
dataset_dir=config["dataset_dir"],
dataset_name=config["dataset_name"],
max_timesteps=config["max_timesteps"],
camera_names=config["camera_names"],
state_dim=config["state_dim"],
overwrite=config["overwrite"],
)
self.timesteps_collected = 0
self.begin_collect = False
self.last_time = None
def callback(
self,
camera_left_img,
camera_right_img,
camera_bottom_img,
camera_head_img,
left_master_arm,
left_slave_arm,
right_master_arm,
right_slave_arm,
):
if self.begin_collect:
self.timesteps_collected += 1
rospy.loginfo(
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
)
else:
self.timesteps_collected = 0
return
if self.timesteps_collected == 0:
return
current_time = time.time()
if self.last_time is not None:
frequency = 1.0 / (current_time - self.last_time)
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
self.last_time = current_time
# Convert the ROS image messages to NumPy arrays
camera_left_np_img = np.frombuffer(
camera_left_img.data, dtype=np.uint8
).reshape(camera_left_img.height, camera_left_img.width, -1)
camera_right_np_img = np.frombuffer(
camera_right_img.data, dtype=np.uint8
).reshape(camera_right_img.height, camera_right_img.width, -1)
camera_bottom_np_img = np.frombuffer(
camera_bottom_img.data, dtype=np.uint8
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
camera_head_np_img = np.frombuffer(
camera_head_img.data, dtype=np.uint8
).reshape(camera_head_img.height, camera_head_img.width, -1)
# Extract the arm joint angles, velocities, and efforts
left_master_arm_angle = left_master_arm.position
# left_master_arm_velocity = left_master_arm.velocity
# left_master_arm_force = left_master_arm.effort
left_slave_arm_angle = left_slave_arm.position
left_slave_arm_velocity = left_slave_arm.velocity
left_slave_arm_force = left_slave_arm.effort
# The Inspire (因时) gripper angle matches the master arm angle; comment this out if not using an Inspire gripper
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
right_master_arm_angle = right_master_arm.position
# right_master_arm_velocity = right_master_arm.velocity
# right_master_arm_force = right_master_arm.effort
right_slave_arm_angle = right_slave_arm.position
right_slave_arm_velocity = right_slave_arm.velocity
right_slave_arm_force = right_slave_arm.effort
# The Inspire (因时) gripper angle matches the master arm angle; comment this out if not using an Inspire gripper
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
# Collect the observation data
obs = collections.OrderedDict(
{
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
"qvel": np.concatenate(
[left_slave_arm_velocity, right_slave_arm_velocity]
),
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
"images": {
"cam_front": camera_head_np_img,
"cam_low": camera_bottom_np_img,
"cam_left": camera_left_np_img,
"cam_right": camera_right_np_img,
},
}
)
ts = dm_env.TimeStep(
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
)
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
self.data_collector.collect_data(ts, action)
# Check whether enough timesteps have been collected
if self.timesteps_collected >= self.data_collector.max_timesteps:
self.data_collector.save_data()
rospy.loginfo("Data collection complete")
self.data_collector.init_dict()
self.begin_collect = False
self.timesteps_collected = 0
def aloha_state_callback(self, data):
if not self.begin_collect:
self.aloha_state = data.data
print(self.aloha_state[0], self.aloha_state[1])
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
self.begin_collect = True
def run(self):
rospy.loginfo("Starting ROS spin")
rospy.spin()
if __name__ == "__main__":
synchronizer = DataSynchronizer(
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
)
synchronizer.run()
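Collection in both collector scripts only starts once aloha_state_callback sees [1, 1]; a hedged trigger sketch follows. The topic name must match whatever the config maps to ros_topics.left_aloha_state; "/left_slave_arm_aloha_state" is taken from the launch file earlier in this diff and may differ in your setup.

# Hypothetical trigger: publish [1, 1] once to start a recording episode.
import rospy
from std_msgs.msg import Int32MultiArray

rospy.init_node("start_collection")
pub = rospy.Publisher("/left_slave_arm_aloha_state", Int32MultiArray, queue_size=1)
rospy.sleep(1.0)  # give the connection time to establish
pub.publish(Int32MultiArray(data=[1, 1]))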

View File

@@ -0,0 +1,31 @@
<?xml version="1.0"?>
<package format="2">
<name>shadow_rm_aloha</name>
<version>0.0.1</version>
<description>The shadow_rm_aloha package</description>
<maintainer email="your_email@example.com">Your Name</maintainer>
<license>TODO</license>
<buildtool_depend>catkin</buildtool_depend>
<build_depend>rospy</build_depend>
<build_depend>sensor_msgs</build_depend>
<build_depend>std_msgs</build_depend>
<build_depend>cv_bridge</build_depend>
<build_depend>image_transport</build_depend>
<build_depend>message_generation</build_depend>
<build_depend>message_runtime</build_depend>
<exec_depend>rospy</exec_depend>
<exec_depend>sensor_msgs</exec_depend>
<exec_depend>std_msgs</exec_depend>
<exec_depend>cv_bridge</exec_depend>
<exec_depend>image_transport</exec_depend>
<exec_depend>message_runtime</exec_depend>
<export>
</export>
</package>

View File

@@ -0,0 +1,5 @@
# GetArmStatus.srv
---
sensor_msgs/JointState joint_status

View File

@@ -0,0 +1,4 @@
# GetImage.srv
---
bool success
sensor_msgs/Image image

View File

@@ -0,0 +1,4 @@
# MoveArm.srv
float32[] joint_angle
---
bool success
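A hedged client sketch for this service definition; no server for MoveArm appears in this diff, and both the service name "/move_arm" and the generated import path are assumptions.

# Hypothetical MoveArm client; assumes the srv is built by message_generation in this package.
import rospy
from shadow_rm_aloha.srv import MoveArm

rospy.init_node("move_arm_client")
rospy.wait_for_service("/move_arm")
move_arm = rospy.ServiceProxy("/move_arm", MoveArm)
resp = move_arm(joint_angle=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
print("success:", resp.success)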

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,49 @@
import multiprocessing as mp
import time
def collect_data(arm_id, cam_id, data_queue, lock):
while True:
# Simulate data acquisition
arm_data = f"Arm {arm_id} data"
cam_data = f"Cam {cam_id} data"
# Get the current timestamp
timestamp = time.time()
# Put the data into the queue
with lock:
data_queue.put((timestamp, arm_data, cam_data))
# Simulate a high frame rate
time.sleep(0.01)
def main():
num_arms = 4
num_cams = 4
# Create the shared queue and lock
data_queue = mp.Queue()
lock = mp.Lock()
# Spawn the worker processes
processes = []
for i in range(num_arms):
p = mp.Process(target=collect_data, args=(i, i, data_queue, lock))
processes.append(p)
p.start()
# The main process consumes the data
try:
while True:
if not data_queue.empty():
with lock:
timestamp, arm_data, cam_data = data_queue.get()
print(f"Timestamp: {timestamp}, {arm_data}, {cam_data}")
except KeyboardInterrupt:
for p in processes:
p.terminate()
for p in processes:
p.join()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,38 @@
import os
import shutil
from datetime import datetime
from shadow_rm_aloha.data_sub_process.aloha_data_synchronizer import DataCollector
def test_create_dataset_dir():
# Test parameters
dataset_dir = './test_data/dataset'
dataset_name = 'test_episode'
max_timesteps = 100
camera_names = ['cam1', 'cam2']
state_dim = 14  # e.g. two arms with 6 joints + 1 gripper each
overwrite = False
# Clean up any stale test data
if os.path.exists(dataset_dir):
shutil.rmtree(dataset_dir)
# Create a DataCollector instance (create_dataset_dir is called in __init__)
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite)
# Check that the date-stamped directory was created
date_str = datetime.now().strftime("%Y%m%d")
expected_dir = os.path.join(dataset_dir, date_str)
assert os.path.exists(expected_dir), f"Expected directory {expected_dir} does not exist."
# Check that create_file assigns the expected (suffix-less) path
collector.create_file()
expected_file = os.path.join(expected_dir, dataset_name + '_0')
assert collector.dataset_path == expected_file, f"Expected file path {expected_file}, but got {collector.dataset_path}"
# Create the first file on disk, call create_file again, and check the name increments
open(expected_file + '.hdf5', 'w').close()
collector.create_file()
expected_file_incremented = os.path.join(expected_dir, dataset_name + '_1')
assert collector.dataset_path == expected_file_incremented, f"Expected file path {expected_file_incremented}, but got {collector.dataset_path}"
print("All tests passed.")
if __name__ == '__main__':
test_create_dataset_dir()

View File

@@ -0,0 +1,105 @@
import multiprocessing
import time
import random
import socket
import json
import logging
# Configure the logging level
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class test_udp():
def __init__(self):
arm_ip = '192.168.1.19'
arm_port = 8080
self.arm = socket.socket()
self.arm.connect((arm_ip, arm_port))
set_udp = {"command":"set_realtime_push","cycle":1,"enable":True,"port":8090,"ip":"192.168.1.101","custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}}
self.arm.send(json.dumps(set_udp).encode('utf-8'))
state = self.arm.recv(1024)
logging.info(f"Send data to {arm_ip}:{arm_port}: {state}")
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# Set socket options to allow address/port reuse
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
local_ip = "192.168.1.101"
local_port = 8090
self.udp_socket.bind((local_ip, local_port))
self.BUFFER_SIZE = 1024
def set_udp(self):
while True:
start_time = time.time()
data, addr = self.udp_socket.recvfrom(self.BUFFER_SIZE)
# Decode the received UDP payload and parse it as JSON
data = json.loads(data.decode('utf-8'))
elapsed = time.time() - start_time
print(f"Received data after {elapsed * 1000:.1f} ms: {data}")
def collect_arm_data(arm_id, queue, event):
while True:
data = f"Arm {arm_id} data {random.random()}"
queue.put((arm_id, data))
event.set()
time.sleep(1)
def collect_camera_data(camera_id, queue, event):
while True:
data = f"Camera {camera_id} data {random.random()}"
queue.put((camera_id, data))
event.set()
time.sleep(1)
def main():
arm_queues = [multiprocessing.Queue() for _ in range(4)]
camera_queues = [multiprocessing.Queue() for _ in range(4)]
arm_events = [multiprocessing.Event() for _ in range(4)]
camera_events = [multiprocessing.Event() for _ in range(4)]
arm_processes = [multiprocessing.Process(target=collect_arm_data, args=(i, arm_queues[i], arm_events[i])) for i in range(4)]
camera_processes = [multiprocessing.Process(target=collect_camera_data, args=(i, camera_queues[i], camera_events[i])) for i in range(4)]
for p in arm_processes + camera_processes:
p.start()
try:
while True:
for event in arm_events + camera_events:
event.wait()
for i in range(4):
if not arm_queues[i].empty():
arm_id, arm_data = arm_queues[i].get()
print(f"Received from Arm {arm_id}: {arm_data}")
arm_events[i].clear()
if not camera_queues[i].empty():
camera_id, camera_data = camera_queues[i].get()
print(f"Received from Camera {camera_id}: {camera_data}")
camera_events[i].clear()
time.sleep(0.1)
except KeyboardInterrupt:
for p in arm_processes + camera_processes:
p.terminate()
if __name__ == "__main__":
main()
# if __name__ == "__main__":
# test_udp = test_udp()
# test_udp.set_udp()

View File

@@ -0,0 +1,4 @@
__pycache__/
*.pyc
*.pyo
*.pt

Some files were not shown because too many files have changed in this diff.