add realman shadow src

This commit is contained in:
2025-06-07 11:29:43 +08:00
parent e079566597
commit cf8df17d3a
98 changed files with 14215 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
__pycache__/
build/
devel/
dist/
data/
.catkin_workspace
*.pyc
*.pyo
*.pt
.vscode/

View File

@@ -0,0 +1,3 @@
# Default ignored files
/shelf/
/workspace.xml

View File

@@ -0,0 +1 @@
aloha_data_synchronizer.py

View File

@@ -0,0 +1,17 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="4">
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
<item index="1" class="java.lang.String" itemvalue="thop" />
<item index="2" class="java.lang.String" itemvalue="torch" />
<item index="3" class="java.lang.String" itemvalue="torchvision" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.11 (随箱软件)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (随箱软件)" project-jdk-type="Python SDK" />
</project>

View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" filepath="$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" />
</modules>
</component>
</project>

View File

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@@ -0,0 +1,38 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset'
dataset_name: 'episode'
max_timesteps: 500
state_dim: 14
overwrite: False
arm_axis: 6
camera_names:
- 'cam_high'
- 'cam_low'
- 'cam_left'
- 'cam_right'
ros_topics:
camera_left: '/camera_left/rgb/image_raw'
camera_right: '/camera_right/rgb/image_raw'
camera_bottom: '/camera_bottom/rgb/image_raw'
camera_head: '/camera_head/rgb/image_raw'
left_master_arm: '/left_master_arm_joint_states'
left_slave_arm: '/left_slave_arm_joint_states'
right_master_arm: '/right_master_arm_joint_states'
right_slave_arm: '/right_slave_arm_joint_states'
left_aloha_state: '/left_slave_arm_aloha_state'
right_aloha_state: '/right_slave_arm_aloha_state'
robot_env: {
# TODO change the path to the correct one
rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml',
rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml',
arm_axis: 6,
head_camera: '241122071186',
bottom_camera: '152122078546',
left_camera: '150622070125',
right_camera: '151222072576',
init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0],
init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0]
# init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0],
# init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0]
# init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
# init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
}

View File

@@ -0,0 +1,9 @@
arm_ip: "192.168.1.18"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8089
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
get_vel: True
get_torque: True

View File

@@ -0,0 +1,9 @@
arm_ip: "192.168.1.19"
arm_port: 8080
arm_axis: 6
local_ip: "192.168.1.101"
local_port: 8090
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
get_vel: True
get_torque: True

View File

@@ -0,0 +1,4 @@
port: /dev/ttyUSB1
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6

View File

@@ -0,0 +1,4 @@
port: /dev/ttyUSB0
baudrate: 460800
hex_data: "55 AA 02 00 00 67"
arm_axis: 6

View File

@@ -0,0 +1,6 @@
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
dataset_name: 'episode'
episode_idx: 1
FPS: 30
# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"] # 7 joints
joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] # 6 joints

View File

@@ -0,0 +1,39 @@
[tool.poetry]
name = "shadow_rm_aloha"
version = "0.1.1"
description = "aloha package, use D435 and Realman robot arm to build aloha to collect data"
readme = "README.md"
authors = ["Shadow <qiuchengzhan@gmail.com>"]
license = "MIT"
# include = ["realman_vision/pytransform/_pytransform.so",]
classifiers = [
"Operating System :: POSIX :: Linux amd64",
"Programming Language :: Python :: 3.10",
]
[tool.poetry.dependencies]
python = ">=3.10"
matplotlib = ">=3.9.2"
h5py = ">=3.12.1"
# rospy = ">=1.17.0"
# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" }
# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" }
[tool.poetry.dev-dependencies] # Development-only dependencies (testing, documentation tooling, etc.).
pytest = ">=8.3"
black = ">=24.10.0"
[tool.poetry.plugins."scripts"] # Command-line entry points so users can run the listed functions from the shell.
[tool.poetry.group.dev.dependencies]
[build-system]
requires = ["poetry-core>=1.8.4"]
build-backend = "poetry.core.masonry.api"

View File

@@ -0,0 +1,42 @@
# Catkin build configuration for the shadow_rm_aloha ROS package.
cmake_minimum_required(VERSION 3.0.2)
project(shadow_rm_aloha)
# Build-time catkin components; message_generation is needed for the .srv files below.
find_package(catkin REQUIRED COMPONENTS
rospy
sensor_msgs
cv_bridge
image_transport
std_msgs
message_generation
)
# Service definitions exposed by this package.
add_service_files(
FILES
GetArmStatus.srv
GetImage.srv
MoveArm.srv
)
# Generate service code against the message packages the .srv files reference.
generate_messages(
DEPENDENCIES
sensor_msgs
std_msgs
)
# Run-time dependencies advertised to downstream packages.
catkin_package(
CATKIN_DEPENDS message_runtime rospy std_msgs
)
include_directories(
${catkin_INCLUDE_DIRS}
)
# Executable Python nodes installed into the package bin destination.
install(PROGRAMS
arm_node/slave_arm_publisher.py
arm_node/master_arm_publisher.py
arm_node/slave_arm_pub_sub.py
camera_node/camera_publisher.py
data_sub_process/aloha_data_synchronizer.py
data_sub_process/aloha_data_collect.py
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
)

View File

@@ -0,0 +1 @@
__version__ = '0.1.0'

View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import rospy
import logging
from shadow_rm_robot.servo_robotic_arm import ServoArm
from sensor_msgs.msg import JointState

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class MasterArmPublisher:
    """Publish the master (teleoperation) arm's joint states on a ROS topic."""

    def __init__(self):
        rospy.init_node("master_arm_publisher", anonymous=True)
        arm_config = rospy.get_param("~arm_config", "config/servo_left_arm.yaml")
        hz = rospy.get_param("~hz", 250)
        self.joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
        self.arm = ServoArm(arm_config)
        self.publisher = rospy.Publisher(self.joint_states_topic, JointState, queue_size=1)
        # Publishing rate in Hz (default 250; the old "# 30 Hz" comment was stale).
        self.rate = rospy.Rate(hz)

    def publish_joint_states(self):
        """Read joint positions from the servo arm and publish them until shutdown."""
        while not rospy.is_shutdown():
            joint_state = JointState()
            joint_pos = self.arm.get_joint_actions()
            joint_state.header.stamp = rospy.Time.now()
            joint_state.name = list(joint_pos.keys())
            joint_state.position = list(joint_pos.values())
            # The master arm reports only positions; publish zero velocity/effort.
            joint_state.velocity = [0.0] * len(joint_pos)
            joint_state.effort = [0.0] * len(joint_pos)
            self.publisher.publish(joint_state)
            self.rate.sleep()


if __name__ == "__main__":
    try:
        arm_publisher = MasterArmPublisher()
        arm_publisher.publish_joint_states()
    except rospy.ROSInterruptException:
        pass

View File

@@ -0,0 +1,63 @@
#!/usr/bin/env python3
import rospy
import logging
from shadow_rm_robot.realman_arm import RmArm
from sensor_msgs.msg import JointState

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class SlaveArmPublisher:
    """Bridge a Realman slave arm to ROS.

    Publishes the arm's current joint angles and subscribes to joint action
    commands, forwarding them to the arm through the vendor API.
    """

    def __init__(self):
        rospy.init_node("slave_arm_publisher", anonymous=True)
        arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
        hz = rospy.get_param("~hz", 250)
        joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
        joint_actions_topic = rospy.get_param("~joint_actions_topic", "/joint_actions")
        self.arm = RmArm(arm_config)
        self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
        self.subscriber = rospy.Subscriber(joint_actions_topic, JointState, self.callback)
        self.rate = rospy.Rate(hz)

    def publish_joint_states(self):
        """Publish the arm's current joint angles until node shutdown."""
        while not rospy.is_shutdown():
            joint_state = JointState()
            data = self.arm.get_integrate_data()
            joint_state.header.stamp = rospy.Time.now()
            joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
            joint_state.position = data['arm_angle']
            self.publisher.publish(joint_state)
            self.rate.sleep()

    def callback(self, data):
        """Dispatch an incoming action message to the arm.

        ``data.name`` selects the command mode: ``["joint_canfd"]`` streams a
        CANFD servo target, ``["joint_j"]`` executes a planned joint move.
        Only the first six positions (the arm joints) are forwarded.
        """
        start_time = rospy.Time.now()
        if data is None:
            return
        if data.name == ["joint_canfd"]:
            self.arm.set_joint_canfd_position(data.position[0:6])
        elif data.name == ["joint_j"]:
            self.arm.set_joint_position(data.position[0:6])
        end_time = rospy.Time.now()
        time_cost_ms = (end_time - start_time).to_sec() * 1000
        rospy.loginfo(f"Time cost: {data.name},{time_cost_ms}")


if __name__ == "__main__":
    try:
        arm_publisher = SlaveArmPublisher()
        arm_publisher.publish_joint_states()
    except rospy.ROSInterruptException:
        pass

View File

@@ -0,0 +1,44 @@
#!/usr/bin/env python3
import rospy
import logging
from shadow_rm_robot.realman_arm import RmArm
from sensor_msgs.msg import JointState
from std_msgs.msg import Int32MultiArray

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class SlaveArmPublisher:
    """Publish a Realman slave arm's joint states plus its aloha state vector."""

    def __init__(self):
        rospy.init_node("slave_arm_publisher", anonymous=True)
        arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
        hz = rospy.get_param("~hz", 250)
        joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
        aloha_state_topic = rospy.get_param("~aloha_state_topic", "/aloha_state")
        self.arm = RmArm(arm_config)
        self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
        self.aloha_state_pub = rospy.Publisher(aloha_state_topic, Int32MultiArray, queue_size=1)
        self.rate = rospy.Rate(hz)

    def publish_joint_states(self):
        """Publish joint state and aloha state at the configured rate until shutdown."""
        while not rospy.is_shutdown():
            joint_state = JointState()
            data = self.arm.get_integrate_data()
            joint_state.header.stamp = rospy.Time.now()
            # NOTE(review): seven names here (arm joints + gripper?) while the
            # sibling publisher uses six — confirm against the arm's reported DOF.
            joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"]
            joint_state.position = list(data["arm_angle"])
            joint_state.velocity = list(data["arm_velocity"])
            joint_state.effort = list(data["arm_torque"])
            # BUG FIX: materialize the dict view into a list — the message field
            # expects a concrete sequence, not a live dict_values view.
            self.aloha_state_pub.publish(Int32MultiArray(data=list(data["aloha_state"].values())))
            self.publisher.publish(joint_state)
            self.rate.sleep()


if __name__ == "__main__":
    try:
        arm_publisher = SlaveArmPublisher()
        arm_publisher.publish_joint_states()
    except rospy.ROSInterruptException:
        pass

View File

@@ -0,0 +1,69 @@
#!/usr/bin/env python3
import rospy
from sensor_msgs.msg import Image
from std_msgs.msg import Header
import numpy as np
from shadow_camera.realsense import RealSenseCamera


class CameraPublisher:
    """Publish RGB frames from a RealSense camera as sensor_msgs/Image messages."""

    def __init__(self):
        rospy.init_node('camera_publisher', anonymous=True)
        self.serial_number = rospy.get_param('~serial_number', None)
        hz = rospy.get_param('~hz', 30)
        rospy.loginfo(f"Serial number: {self.serial_number}")
        self.rgb_topic = rospy.get_param('~rgb_topic', '/camera/rgb/image_raw')
        self.depth_topic = rospy.get_param('~depth_topic', '/camera/depth/image_raw')
        rospy.loginfo(f"RGB topic: {self.rgb_topic}")
        rospy.loginfo(f"Depth topic: {self.depth_topic}")
        self.rgb_pub = rospy.Publisher(self.rgb_topic, Image, queue_size=10)
        # Depth publishing is currently disabled; only RGB frames are emitted below.
        self.rate = rospy.Rate(hz)
        self.camera = RealSenseCamera(self.serial_number, False)
        rospy.loginfo("Camera initialized")

    def publish_images(self):
        """Stream frames from the camera until shutdown, then stop the camera."""
        self.camera.start_camera()
        rospy.loginfo("Camera started")
        while not rospy.is_shutdown():
            result = self.camera.read_frame(True, False, False, False)
            if result is None:
                rospy.logerr("Failed to read frame from camera")
                continue
            color_image, depth_image, _, _ = result
            # BUG FIX: the previous test (`color_image is not None or
            # depth_image is not None`) crashed when only the depth frame
            # arrived, because the RGB message is built from color_image
            # unconditionally. Only the color frame matters here.
            if color_image is not None:
                rgb_msg = self.create_image_msg(color_image, "bgr8")
                self.rgb_pub.publish(rgb_msg)
            else:
                rospy.logwarn("Received None for color_image or depth_image")
            self.rate.sleep()
        self.camera.stop_camera()
        rospy.loginfo("Camera stopped")

    def create_image_msg(self, image, encoding):
        """Wrap a numpy image of shape (H, W, C) in a sensor_msgs/Image.

        `encoding` is the ROS image encoding string (e.g. "bgr8", "mono16").
        """
        msg = Image()
        msg.header = Header()
        msg.header.stamp = rospy.Time.now()
        msg.height, msg.width = image.shape[:2]
        msg.encoding = encoding
        msg.is_bigendian = False
        msg.step = image.strides[0]
        msg.data = np.array(image).tobytes()
        return msg


if __name__ == '__main__':
    try:
        camera_publisher = CameraPublisher()
        camera_publisher.publish_images()
    except rospy.ROSInterruptException:
        pass

View File

@@ -0,0 +1,112 @@
import os
import time
import h5py
import logging
from datetime import datetime
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
class DataCollector:
    """Load, patch, and re-save recorded episode data (offline post-processing).

    Used by this script's __main__ block to rewrite the gripper columns of
    existing HDF5 episodes: `load_hdf5` copies the commanded (master) gripper
    values from /action into /observations/qpos, and `save_data` writes the
    patched episode into a dated output directory.
    """

    def __init__(self, dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite=False):
        # Joints per arm; index `arm_axis` is the gripper slot within each half
        # of the concatenated [left | right] state vector (state_dim = 16 -> 8 per side).
        self.arm_axis = 7
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.max_timesteps = max_timesteps
        self.camera_names = camera_names
        self.state_dim = state_dim
        self.overwrite = overwrite
        self.data_dict = {
            '/observations/qpos': [],
            '/observations/qvel': [],
            '/observations/effort': [],
            '/action': [],
        }
        for cam_name in camera_names:
            self.data_dict[f'/observations/images/{cam_name}'] = []
        # Create (or reuse) the dated output directory.
        self.create_dataset_dir()
        self.timesteps_collected = 0

    def create_dataset_dir(self):
        """Append a YYYYMMDD subdirectory to dataset_dir and create it if needed."""
        date_str = datetime.now().strftime("%Y%m%d")
        self.dataset_dir = os.path.join(self.dataset_dir, date_str)
        os.makedirs(self.dataset_dir, exist_ok=True)

    def collect_data(self, ts, action):
        """Append one timestep (dm_env-style TimeStep `ts` and its `action`)."""
        self.data_dict['/observations/qpos'].append(ts.observation['qpos'])
        self.data_dict['/observations/qvel'].append(ts.observation['qvel'])
        self.data_dict['/observations/effort'].append(ts.observation['effort'])
        self.data_dict['/action'].append(action)
        for cam_name in self.camera_names:
            self.data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])

    def save_data(self):
        """Write data_dict to self.dataset_path (set by load_hdf5) as HDF5."""
        t0 = time.time()
        with h5py.File(self.dataset_path, mode='w', rdcc_nbytes=1024**2*2) as root:
            root.attrs['sim'] = False
            obs = root.create_group('observations')
            image = obs.create_group('images')
            for cam_name in self.camera_names:
                _ = image.create_dataset(cam_name, (self.max_timesteps, 480, 640, 3), dtype='uint8',
                                         chunks=(1, 480, 640, 3))
            _ = obs.create_dataset('qpos', (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset('qvel', (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset('effort', (self.max_timesteps, self.state_dim))
            _ = root.create_dataset('action', (self.max_timesteps, self.state_dim))
            for name, array in self.data_dict.items():
                root[name][...] = array
        print(f'Saving: {time.time() - t0:.1f} secs')
        return True

    def load_hdf5(self, orign_path, file):
        """Load an episode from `orign_path` and stage it for re-saving as `file`.

        Copies the commanded gripper trajectory from /action into the qpos
        gripper columns (indices arm_axis and arm_axis*2+1), then mirrors all
        arrays into data_dict so save_data can write them back out.
        """
        self.dataset_path = os.path.join(self.dataset_dir, file)
        if not os.path.isfile(orign_path):
            logging.error(f'Dataset does not exist at {orign_path}')
            # Was a bare exit(); SystemExit(1) works without the site module
            # and reports failure with a nonzero exit code.
            raise SystemExit(1)
        with h5py.File(orign_path, 'r') as root:
            self.is_sim = root.attrs['sim']
            self.qpos = root['/observations/qpos'][()]
            self.qvel = root['/observations/qvel'][()]
            self.effort = root['/observations/effort'][()]
            self.action = root['/action'][()]
            self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
                               for cam_name in root['/observations/images/'].keys()}
        # Replace recorded slave-gripper positions with the commanded (master) ones.
        self.qpos[:, self.arm_axis] = self.action[:, self.arm_axis]
        self.qpos[:, self.arm_axis*2+1] = self.action[:, self.arm_axis*2+1]
        self.data_dict['/observations/qpos'] = self.qpos
        self.data_dict['/observations/qvel'] = self.qvel
        self.data_dict['/observations/effort'] = self.effort
        self.data_dict['/action'] = self.action
        for cam_name in self.camera_names:
            self.data_dict[f'/observations/images/{cam_name}'] = self.image_dict[cam_name]
        return True
if __name__ == '__main__':
    # Post-process recorded episodes: overwrite the slave-gripper columns with
    # the commanded (master) gripper values, then re-save each file.
    dataset_dir = '/home/wang/project/shadow_rm_aloha/data'
    orign_dir = '/home/wang/project/shadow_rm_aloha/data/dataset/20241128'
    dataset_name = 'test'
    max_timesteps = 300
    camera_names = ['cam_high', 'cam_low', 'cam_left', 'cam_right']
    state_dim = 16
    for file in os.listdir(orign_dir):
        # Build a fresh collector per episode instead of calling __init__
        # explicitly on one instance (the old pattern), so no buffered data
        # leaks between files.
        collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
        orign_path = os.path.join(orign_dir, file)
        print(orign_path)
        collector.load_hdf5(orign_path, file)
        collector.save_data()

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
import os
import numpy as np
import h5py
import yaml
import logging
import time
from shadow_rm_robot.realman_arm import RmArm

logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")


class DataValidator:
    """Replay a recorded episode on the real arm to sanity-check the data."""

    def __init__(self, config):
        self.dataset_dir = config['dataset_dir']
        self.episode_idx = config['episode_idx']
        self.joint_names = config['joint_names']
        self.dataset_name = f'episode_{self.episode_idx}'
        self.dataset_path = os.path.join(self.dataset_dir, self.dataset_name + '.hdf5')
        self.state_names = self.joint_names + ["gripper"]
        self.arm = RmArm('/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml')

    def load_hdf5(self):
        """Load qpos/action/images from the episode file; abort if it is missing."""
        if not os.path.isfile(self.dataset_path):
            logging.error(f'Dataset does not exist at {self.dataset_path}')
            # Was a bare exit(); SystemExit(1) reports failure with a nonzero code.
            raise SystemExit(1)
        with h5py.File(self.dataset_path, 'r') as root:
            self.is_sim = root.attrs['sim']
            self.qpos = root['/observations/qpos'][()]
            self.action = root['/action'][()]
            self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
                               for cam_name in root['/observations/images/'].keys()}

    def validate_data(self):
        """Check the state vector has the expected width (2 arms x 7 = 14)."""
        if not self.qpos.shape[1] == 14:
            logging.error('qpos shape does not match expected number of joints')
            return False
        logging.info('Data validation passed')
        return True

    def control_robot(self):
        """Replay the episode: plan to the first pose, then stream CANFD targets.

        Layout assumption: qpos = [left joints 0:6, left gripper 6,
        right joints 7:13, right gripper 13] — TODO confirm against the recorder.
        NOTE(review): the initial planned move uses qpos[0][0:6] (the left-arm
        slice) while streaming uses the right-arm slice 7:13 — verify which arm
        is intended here.
        """
        self.arm.set_joint_position(self.qpos[0][0:6])
        for qpos in self.qpos:
            logging.info(f'qpos: {qpos}')
            self.arm.set_joint_canfd_position(qpos[7:13])
            self.arm.set_gripper_position(qpos[13])
            time.sleep(0.035)  # ~28 Hz streaming period

    def run(self):
        """Load, validate, and (if valid) replay the episode."""
        self.load_hdf5()
        if self.validate_data():
            self.control_robot()


def load_config(config_path):
    """Load a YAML configuration file into a dict."""
    with open(config_path, 'r') as file:
        return yaml.safe_load(file)


if __name__ == '__main__':
    config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
    validator = DataValidator(config)
    validator.run()

View File

@@ -0,0 +1,147 @@
#!/usr/bin/env python3
import os
import numpy as np
import cv2
import h5py
import yaml
import logging
import matplotlib.pyplot as plt
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
class DataVisualizer:
    """Render recorded HDF5 episodes as videos and joint-trajectory plots."""

    def __init__(self, config):
        self.dataset_dir = config['dataset_dir']
        self.episode_idx = config['episode_idx']
        self.dt = 1/config['FPS']
        self.joint_names = config['joint_names']
        self.state_names = self.joint_names + ["gripper"]

    def join_file_path(self, file_name):
        """Remember the absolute path of the episode file to visualize."""
        self.dataset_path = os.path.join(self.dataset_dir, file_name)

    def load_hdf5(self):
        """Load qpos/action/images from dataset_path; abort if the file is missing."""
        if not os.path.isfile(self.dataset_path):
            logging.error(f'Dataset does not exist at {self.dataset_path}')
            # Was a bare exit(); SystemExit(1) reports failure with a nonzero code.
            raise SystemExit(1)
        with h5py.File(self.dataset_path, 'r') as root:
            self.is_sim = root.attrs['sim']
            self.qpos = root['/observations/qpos'][()]
            self.action = root['/action'][()]
            self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
                               for cam_name in root['/observations/images/'].keys()}

    def save_videos(self, video, dt, video_path=None):
        """Write camera streams side by side to an mp4.

        `video` is either a list of per-timestep {cam: frame} dicts or a dict
        of per-camera (T, H, W, 3) arrays; frames are RGB and the channel
        reorder [2, 1, 0] converts them to BGR for OpenCV.
        """
        fps = int(1 / dt)
        if isinstance(video, list):
            cam_names = list(video[0].keys())
            h, w, _ = video[0][cam_names[0]].shape
            w = w * len(cam_names)
            out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            for image_dict in video:
                images = [image_dict[cam_name][:, :, [2, 1, 0]] for cam_name in cam_names]
                out.write(np.concatenate(images, axis=1))
            out.release()
            logging.info(f'Saved video to: {video_path}')
        elif isinstance(video, dict):
            cam_names = list(video.keys())
            all_cam_videos = np.concatenate([video[cam_name] for cam_name in cam_names], axis=2)
            n_frames, h, w, _ = all_cam_videos.shape
            out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
            for t in range(n_frames):
                out.write(all_cam_videos[t][:, :, [2, 1, 0]])
            out.release()
            logging.info(f'Saved video to: {video_path}')

    def visualize_joints(self, qpos_list, command_list, plot_path, ylim=None, label_overwrite=None):
        """Plot recorded state vs. commanded action for every joint dimension."""
        label1, label2 = label_overwrite if label_overwrite else ('State', 'Command')
        qpos = np.array(qpos_list)
        command = np.array(command_list)
        num_ts, num_dim = qpos.shape
        logging.info(f'qpos shape: {qpos.shape}, command shape: {command.shape}')
        fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
        # BUG FIX: subplots() returns a bare Axes (not an array) when num_dim == 1,
        # which made axs[dim_idx] fail; normalize to a 1-D array of Axes.
        axs = np.atleast_1d(axs)
        all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
        for dim_idx in range(num_dim):
            ax = axs[dim_idx]
            ax.plot(qpos[:, dim_idx], label=label1)
            ax.plot(command[:, dim_idx], label=label2)
            ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
            ax.legend()
            if ylim:
                ax.set_ylim(ylim)
        plt.tight_layout()
        plt.savefig(plot_path)
        logging.info(f'Saved qpos plot to: {plot_path}')
        plt.close()

    def visualize_single(self, data_list, label, plot_path, ylim=None):
        """Plot a single per-joint series (e.g. effort or tracking error)."""
        data = np.array(data_list)
        num_ts, num_dim = data.shape
        fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
        # Same num_dim == 1 normalization as in visualize_joints.
        axs = np.atleast_1d(axs)
        all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
        for dim_idx in range(num_dim):
            ax = axs[dim_idx]
            ax.plot(data[:, dim_idx], label=label)
            ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
            ax.legend()
            if ylim:
                ax.set_ylim(ylim)
        plt.tight_layout()
        plt.savefig(plot_path)
        logging.info(f'Saved {label} plot to: {plot_path}')
        plt.close()

    def visualize_timestamp(self, t_list):
        """Plot camera frame timestamps and their successive differences."""
        plot_path = self.dataset_path.replace('.hdf5', '_timestamp.png')
        fig, axs = plt.subplots(2, 1, figsize=(10, 8))
        t_float = np.array([secs + nsecs * 1e-9 for secs, nsecs in t_list])
        axs[0].plot(np.arange(len(t_float)), t_float)
        axs[0].set_title('Camera frame timestamps')
        axs[0].set_xlabel('timestep')
        axs[0].set_ylabel('time (sec)')
        # BUG FIX: plot forward differences t[i+1] - t[i]; the old code plotted
        # t[i] - t[i+1], which rendered dt with an inverted sign.
        axs[1].plot(np.arange(len(t_float) - 1), t_float[1:] - t_float[:-1])
        axs[1].set_title('dt')
        axs[1].set_xlabel('timestep')
        axs[1].set_ylabel('time (sec)')
        plt.tight_layout()
        plt.savefig(plot_path)
        logging.info(f'Saved timestamp plot to: {plot_path}')
        plt.close()

    def run(self):
        """Visualize every .hdf5 episode in dataset_dir (video + qpos plot)."""
        for file_name in os.listdir(self.dataset_dir):
            if file_name.endswith('.hdf5'):
                self.join_file_path(file_name)
                self.load_hdf5()
                video_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_video.mp4'))
                self.save_videos(self.image_dict, self.dt, video_path)
                qpos_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_qpos.png'))
                self.visualize_joints(self.qpos, self.action, qpos_plot_path)
def load_config(config_path):
    """Parse a YAML configuration file and return its contents as a dict."""
    with open(config_path, 'r') as fh:
        return yaml.safe_load(fh)
if __name__ == '__main__':
    # Entry point: read the visualization config and render all episodes.
    cfg_path = '/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml'
    visualizer = DataVisualizer(load_config(cfg_path))
    visualizer.run()

View File

@@ -0,0 +1,180 @@
#!/usr/bin/env python3
import time
import yaml
import rospy
import dm_env
import numpy as np
import collections
from datetime import datetime
from sensor_msgs.msg import Image, JointState
from shadow_rm_robot.realman_arm import RmArm
from message_filters import Subscriber, ApproximateTimeSynchronizer


class DataSynchronizer:
    """Synchronize four camera streams with both slave arms into dm_env TimeSteps.

    Acts as a minimal environment: `reset()` drives both arms to their initial
    pose and waits for the first observation, `step()` streams CANFD joint
    targets, and the latest synchronized observation is returned as a TimeStep.
    """

    def __init__(self, config_path="config"):
        rospy.init_node("synchronizer", anonymous=True)
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        self.init_left_arm_angle = config["robot_env"]["init_left_arm_angle"]
        self.init_right_arm_angle = config["robot_env"]["init_right_arm_angle"]
        self.arm_axis = config["robot_env"]["arm_axis"]
        self.camera_names = config["camera_names"]
        # Subscribers for the four camera streams and the two slave arms.
        self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
        self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
        self.camera_bottom_sub = Subscriber(
            config["ros_topics"]["camera_bottom"], Image
        )
        self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
        # NOTE(review): the bundled config.yaml defines `left_slave_arm` /
        # `right_slave_arm` but no `*_sub` / `*_pub` variants — confirm which
        # config file this node is launched with, otherwise these lookups
        # raise KeyError.
        self.left_slave_arm_sub = Subscriber(
            config["ros_topics"]["left_slave_arm_sub"], JointState
        )
        self.right_slave_arm_sub = Subscriber(
            config["ros_topics"]["right_slave_arm_sub"], JointState
        )
        self.left_slave_arm_pub = rospy.Publisher(
            config["ros_topics"]["left_slave_arm_pub"], JointState, queue_size=1
        )
        self.right_slave_arm_pub = rospy.Publisher(
            config["ros_topics"]["right_slave_arm_pub"], JointState, queue_size=1
        )
        # Approximate-time synchronizer across all six streams (100 ms slop).
        self.ats = ApproximateTimeSynchronizer(
            [
                self.camera_left_sub,
                self.camera_right_sub,
                self.camera_bottom_sub,
                self.camera_head_sub,
                self.left_slave_arm_sub,
                self.right_slave_arm_sub,
            ],
            queue_size=1,
            slop=0.1,
        )
        self.ats.registerCallback(self.callback)
        self.ts = None  # most recent synchronized TimeStep
        self.is_first_step = True  # fixed typo: was `is_frist_step`

    def callback(
        self,
        camera_left_img,
        camera_right_img,
        camera_bottom_img,
        camera_head_img,
        left_slave_arm,
        right_slave_arm,
    ):
        """Assemble one synchronized observation and cache it as a TimeStep."""
        # Convert the ROS Image messages to numpy arrays of shape (H, W, C).
        camera_left_np_img = np.frombuffer(
            camera_left_img.data, dtype=np.uint8
        ).reshape(camera_left_img.height, camera_left_img.width, -1)
        camera_right_np_img = np.frombuffer(
            camera_right_img.data, dtype=np.uint8
        ).reshape(camera_right_img.height, camera_right_img.width, -1)
        camera_bottom_np_img = np.frombuffer(
            camera_bottom_img.data, dtype=np.uint8
        ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
        camera_head_np_img = np.frombuffer(
            camera_head_img.data, dtype=np.uint8
        ).reshape(camera_head_img.height, camera_head_img.width, -1)
        left_slave_arm_angle = left_slave_arm.position
        left_slave_arm_velocity = left_slave_arm.velocity
        left_slave_arm_force = left_slave_arm.effort
        # For INSPIRE-style grippers the slave gripper angle already mirrors the
        # master's; with other grippers overwrite the gripper slot here from the
        # master arm angle (previously a commented-out assignment).
        right_slave_arm_angle = right_slave_arm.position
        right_slave_arm_velocity = right_slave_arm.velocity
        right_slave_arm_force = right_slave_arm.effort
        # Bundle the observation: [left | right] concatenated state vectors plus
        # the four camera frames keyed by the configured camera names.
        obs = collections.OrderedDict(
            {
                "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
                "qvel": np.concatenate(
                    [left_slave_arm_velocity, right_slave_arm_velocity]
                ),
                "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
                "images": {
                    self.camera_names[0]: camera_head_np_img,
                    self.camera_names[1]: camera_bottom_np_img,
                    self.camera_names[2]: camera_left_np_img,
                    self.camera_names[3]: camera_right_np_img,
                },
            }
        )
        self.ts = dm_env.TimeStep(
            step_type=(
                dm_env.StepType.FIRST if self.is_first_step else dm_env.StepType.MID
            ),
            reward=0.0,
            discount=1.0,
            observation=obs,
        )

    def reset(self):
        """Command both arms to their initial pose and wait for an observation."""
        left_joint_state = JointState()
        left_joint_state.header.stamp = rospy.Time.now()
        left_joint_state.name = ["joint_j"]
        left_joint_state.position = self.init_left_arm_angle[0 : self.arm_axis + 1]
        right_joint_state = JointState()
        right_joint_state.header.stamp = rospy.Time.now()
        right_joint_state.name = ["joint_j"]
        right_joint_state.position = self.init_right_arm_angle[0 : self.arm_axis + 1]
        self.left_slave_arm_pub.publish(left_joint_state)
        self.right_slave_arm_pub.publish(right_joint_state)
        # Block until the synchronizer callback has produced the first TimeStep.
        while self.ts is None:
            time.sleep(0.002)
        return self.ts

    def step(self, target_angle):
        """Stream one CANFD target.

        `target_angle` holds [left | right] halves, each arm_axis + 1 values
        (arm joints plus gripper).
        """
        self.is_first_step = False
        left_joint_state = JointState()
        left_joint_state.header.stamp = rospy.Time.now()
        left_joint_state.name = ["joint_canfd"]
        left_joint_state.position = target_angle[0 : self.arm_axis + 1]
        right_joint_state = JointState()
        right_joint_state.header.stamp = rospy.Time.now()
        right_joint_state.name = ["joint_canfd"]
        right_joint_state.position = target_angle[self.arm_axis + 1 : (self.arm_axis + 1) * 2]
        self.left_slave_arm_pub.publish(left_joint_state)
        self.right_slave_arm_pub.publish(right_joint_state)
        return self.ts

    def run(self):
        """Demo loop: reset, then hold the initial pose by streaming it at ~100 Hz."""
        rospy.loginfo("Starting ROS spin")
        data = np.concatenate([self.init_left_arm_angle, self.init_right_arm_angle])
        self.reset()
        while not rospy.is_shutdown():
            self.step(data)
            rospy.sleep(0.010)


if __name__ == "__main__":
    synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
    synchronizer.run()

View File

@@ -0,0 +1,280 @@
#!/usr/bin/env python3
import os
import time
import h5py
import yaml
import rospy
import dm_env
import numpy as np
import collections
from datetime import datetime
from std_msgs.msg import Int32MultiArray
from sensor_msgs.msg import Image, JointState
from message_filters import Subscriber, ApproximateTimeSynchronizer
class DataCollector:
    """Accumulate synchronized episode data in memory and save it as HDF5."""

    def __init__(
        self,
        dataset_dir,
        dataset_name,
        max_timesteps,
        camera_names,
        state_dim,
        overwrite=False,
    ):
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.max_timesteps = max_timesteps
        self.camera_names = camera_names
        self.state_dim = state_dim
        self.overwrite = overwrite
        self.init_dict()
        self.create_dataset_dir()

    def init_dict(self):
        """(Re)initialize the in-memory buffers for one episode."""
        self.data_dict = {
            "/observations/qpos": [],
            "/observations/qvel": [],
            "/observations/effort": [],
            "/action": [],
        }
        for cam_name in self.camera_names:
            self.data_dict[f"/observations/images/{cam_name}"] = []

    def create_dataset_dir(self):
        """Append a YYYYMMDD subdirectory to dataset_dir and create it if needed."""
        date_str = datetime.now().strftime("%Y%m%d")
        self.dataset_dir = os.path.join(self.dataset_dir, date_str)
        if not os.path.exists(self.dataset_dir):
            os.makedirs(self.dataset_dir)

    def create_file(self):
        """Choose the output path `<dataset_name>_<k>`.

        Unless `overwrite` is set, k is advanced past every existing episode
        file so nothing is clobbered.
        """
        counter = 0
        dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
        if not self.overwrite:
            # Advance the counter before rebuilding the path; the old order
            # rebuilt the same path and re-tested it on the first iteration.
            while os.path.exists(dataset_path + ".hdf5"):
                counter += 1
                dataset_path = os.path.join(
                    self.dataset_dir, f"{self.dataset_name}_{counter}"
                )
        self.dataset_path = dataset_path

    def collect_data(self, ts, action):
        """Append one timestep (dm_env TimeStep `ts` plus the commanded `action`)."""
        self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
        self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
        self.data_dict["/observations/effort"].append(ts.observation["effort"])
        self.data_dict["/action"].append(action)
        for cam_name in self.camera_names:
            self.data_dict[f"/observations/images/{cam_name}"].append(
                ts.observation["images"][cam_name]
            )

    def save_data(self):
        """Write the buffered episode to `<dataset_path>.hdf5` (fixed-size datasets)."""
        self.create_file()
        t0 = time.time()
        with h5py.File(
            self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
        ) as root:
            root.attrs["sim"] = False
            obs = root.create_group("observations")
            image = obs.create_group("images")
            for cam_name in self.camera_names:
                _ = image.create_dataset(
                    cam_name,
                    (self.max_timesteps, 480, 640, 3),
                    dtype="uint8",
                    chunks=(1, 480, 640, 3),
                )
            _ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
            _ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
            _ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
            for name, array in self.data_dict.items():
                root[name][...] = array
        print(f"Saving: {time.time() - t0:.1f} secs")
        return True
class DataSynchronizer:
    """ROS node that time-aligns four camera topics and four arm JointState
    topics and records each synchronized tuple through a DataCollector.

    Recording is armed by the aloha-state topic: once both flags read 1 the
    synchronized callbacks are collected until ``max_timesteps`` steps, then
    the episode is saved to HDF5 and the node re-arms.
    """

    def __init__(self, config_path="config"):
        rospy.init_node("synchronizer", anonymous=True)
        rospy.loginfo("ROS node initialized")
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        # Index of the gripper joint driven by the master arm (used by the
        # commented-out Inspire-gripper mirroring below).
        self.arm_axis = config["arm_axis"]
        # Create subscribers
        self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
        self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
        self.camera_bottom_sub = Subscriber(
            config["ros_topics"]["camera_bottom"], Image
        )
        self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
        self.left_master_arm_sub = Subscriber(
            config["ros_topics"]["left_master_arm"], JointState
        )
        self.left_slave_arm_sub = Subscriber(
            config["ros_topics"]["left_slave_arm"], JointState
        )
        self.right_master_arm_sub = Subscriber(
            config["ros_topics"]["right_master_arm"], JointState
        )
        self.right_slave_arm_sub = Subscriber(
            config["ros_topics"]["right_slave_arm"], JointState
        )
        # NOTE(review): despite the "_pub" suffix this is a Subscriber.
        self.left_aloha_state_pub = rospy.Subscriber(
            config["ros_topics"]["left_aloha_state"],
            Int32MultiArray,
            self.aloha_state_callback,
        )
        rospy.loginfo("Subscribers created")
        self.camera_names = config["camera_names"]
        # Create the approximate-time synchronizer over all eight topics.
        self.ats = ApproximateTimeSynchronizer(
            [
                self.camera_left_sub,
                self.camera_right_sub,
                self.camera_bottom_sub,
                self.camera_head_sub,
                self.left_master_arm_sub,
                self.left_slave_arm_sub,
                self.right_master_arm_sub,
                self.right_slave_arm_sub,
            ],
            queue_size=1,
            slop=0.05,  # messages within 50 ms count as simultaneous
        )
        self.ats.registerCallback(self.callback)
        rospy.loginfo("Time synchronizer created and callback registered")
        self.data_collector = DataCollector(
            dataset_dir=config["dataset_dir"],
            dataset_name=config["dataset_name"],
            max_timesteps=config["max_timesteps"],
            camera_names=config["camera_names"],
            state_dim=config["state_dim"],
            overwrite=config["overwrite"],
        )
        self.timesteps_collected = 0
        self.begin_collect = False  # armed by aloha_state_callback
        self.last_time = None  # wall-clock time of the previous callback

    def callback(
        self,
        camera_left_img,
        camera_right_img,
        camera_bottom_img,
        camera_head_img,
        left_master_arm,
        left_slave_arm,
        right_master_arm,
        right_slave_arm,
    ):
        """Handle one time-synchronized bundle of images and joint states."""
        if self.begin_collect:
            self.timesteps_collected += 1
            rospy.loginfo(
                f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
            )
        else:
            self.timesteps_collected = 0
            return
        # NOTE(review): unreachable after the increment above — kept as a guard.
        if self.timesteps_collected == 0:
            return
        current_time = time.time()
        if self.last_time is not None:
            # Report the effective synchronized-callback rate.
            frequency = 1.0 / (current_time - self.last_time)
            rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
        self.last_time = current_time
        # Convert ROS Image messages to (height, width, channels) uint8 arrays.
        camera_left_np_img = np.frombuffer(
            camera_left_img.data, dtype=np.uint8
        ).reshape(camera_left_img.height, camera_left_img.width, -1)
        camera_right_np_img = np.frombuffer(
            camera_right_img.data, dtype=np.uint8
        ).reshape(camera_right_img.height, camera_right_img.width, -1)
        camera_bottom_np_img = np.frombuffer(
            camera_bottom_img.data, dtype=np.uint8
        ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
        camera_head_np_img = np.frombuffer(
            camera_head_img.data, dtype=np.uint8
        ).reshape(camera_head_img.height, camera_head_img.width, -1)
        # Extract joint angles, velocities and efforts for each arm.
        left_master_arm_angle = left_master_arm.position
        # left_master_arm_velocity = left_master_arm.velocity
        # left_master_arm_force = left_master_arm.effort
        left_slave_arm_angle = left_slave_arm.position
        left_slave_arm_velocity = left_slave_arm.velocity
        left_slave_arm_force = left_slave_arm.effort
        # For Inspire grippers the slave gripper angle mirrors the master arm;
        # keep commented out for other gripper types.
        # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
        right_master_arm_angle = right_master_arm.position
        # right_master_arm_velocity = right_master_arm.velocity
        # right_master_arm_force = right_master_arm.effort
        right_slave_arm_angle = right_slave_arm.position
        right_slave_arm_velocity = right_slave_arm.velocity
        right_slave_arm_force = right_slave_arm.effort
        # For Inspire grippers the slave gripper angle mirrors the master arm;
        # keep commented out for other gripper types.
        # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
        # Assemble the observation: slave arms are the state, cameras the images.
        obs = collections.OrderedDict(
            {
                "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
                "qvel": np.concatenate(
                    [left_slave_arm_velocity, right_slave_arm_velocity]
                ),
                "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
                "images": {
                    self.camera_names[0]: camera_head_np_img,
                    self.camera_names[1]: camera_bottom_np_img,
                    self.camera_names[2]: camera_left_np_img,
                    self.camera_names[3]: camera_right_np_img,
                },
            }
        )
        print(self.camera_names[0])  # NOTE(review): debug leftover
        ts = dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
        )
        # Action = what the operator commanded: the master-arm joint angles.
        action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
        self.data_collector.collect_data(ts, action)
        # Check whether enough timesteps were collected to end the episode.
        if self.timesteps_collected >= self.data_collector.max_timesteps:
            self.data_collector.save_data()
            rospy.loginfo("Data collection complete")
            self.data_collector.init_dict()
            self.begin_collect = False
            self.timesteps_collected = 0

    def aloha_state_callback(self, data):
        """Arm recording when both aloha-state flags read 1."""
        if not self.begin_collect:
            self.aloha_state = data.data
            print(self.aloha_state[0], self.aloha_state[1])
            if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
                self.begin_collect = True

    def run(self):
        """Block in the ROS event loop until shutdown."""
        rospy.loginfo("Starting ROS spin")
        rospy.spin()
if __name__ == "__main__":
    # Launch the synchronizer with the on-robot YAML configuration.
    config_file = "/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
    DataSynchronizer(config_file).run()

View File

@@ -0,0 +1,63 @@
<launch>
<!-- 左从臂节点 -->
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
<param name="aloha_state_topic" value="/left_slave_arm_aloha_state" type= "string"/>
<param name="hz" value="120" type= "int"/>
</node>
<!-- 右从臂节点 -->
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
<param name="aloha_state_topic" value="/right_slave_arm_aloha_state" type= "string"/>
<param name="hz" value="120" type= "int"/>
</node>
<!-- 左主臂节点 -->
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- 右主臂节点 -->
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- 右臂相机节点 -->
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="151222072576" type= "string"/>
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 左臂相机节点 -->
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="150622070125" type= "string"/>
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 顶部相机节点 -->
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="241122071186" type= "string"/>
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 底部相机节点 -->
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="152122078546" type= "string"/>
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
</launch>

View File

@@ -0,0 +1,49 @@
<launch>
<!-- 左从臂节点 -->
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
<param name="joint_actions_topic" value="/left_slave_arm_joint_actions" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- 右从臂节点 -->
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
<param name="joint_actions_topic" value="/right_slave_arm_joint_actions" type= "string"/>
<param name="hz" value="90" type= "int"/>
</node>
<!-- 右臂相机节点 -->
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="151222072576" type= "string"/>
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 左臂相机节点 -->
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="150622070125" type= "string"/>
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 顶部相机节点 -->
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="241122071186" type= "string"/>
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 底部相机节点 -->
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="152122078546" type= "string"/>
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
</launch>

View File

@@ -0,0 +1,61 @@
<launch>
<!-- 左从臂节点 -->
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 右从臂节点 -->
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 左主臂节点 -->
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_left_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 右主臂节点 -->
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_right_arm.yaml" type= "string"/>
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 右臂相机节点 -->
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="216322070299" type= "string"/>
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 左臂相机节点 -->
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="216322074992" type= "string"/>
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 顶部相机节点 -->
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="215322076086" type= "string"/>
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
<!-- 底部相机节点 -->
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
<param name="serial_number" value="215222074360" type= "string"/>
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
<param name="hz" value="50" type= "int"/>
</node>
</launch>

View File

@@ -0,0 +1,284 @@
#!/usr/bin/env python3
import os
import cv2
import time
import h5py
import yaml
import json
import rospy
import dm_env
import socket
import numpy as np
import collections
from datetime import datetime
from std_msgs.msg import Int32MultiArray
from sensor_msgs.msg import Image, JointState
from message_filters import Subscriber, ApproximateTimeSynchronizer
class DataCollector:
    """Accumulates synchronized observations/actions in memory and persists
    one complete episode to an HDF5 file."""

    def __init__(
        self,
        dataset_dir,
        dataset_name,
        max_timesteps,
        camera_names,
        state_dim,
        overwrite=False,
    ):
        # Store configuration, then prepare buffers and the output directory.
        self.dataset_dir = dataset_dir
        self.dataset_name = dataset_name
        self.max_timesteps = max_timesteps
        self.camera_names = camera_names
        self.state_dim = state_dim
        self.overwrite = overwrite
        self.init_dict()
        self.create_dataset_dir()

    def init_dict(self):
        """Reset the per-episode buffers."""
        base_keys = (
            "/observations/qpos",
            "/observations/qvel",
            "/observations/effort",
            "/action",
        )
        self.data_dict = {key: [] for key in base_keys}
        self.data_dict.update(
            {f"/observations/images/{cam}": [] for cam in self.camera_names}
        )

    def create_dataset_dir(self):
        """Create a YYYYMMDD-stamped subdirectory for today's episodes."""
        today = datetime.now().strftime("%Y%m%d")
        self.dataset_dir = os.path.join(self.dataset_dir, today)
        if not os.path.exists(self.dataset_dir):
            os.makedirs(self.dataset_dir)

    def create_file(self):
        """Choose ``<dataset_name>_<idx>`` (no extension), auto-incrementing
        the suffix until a free name is found (unless overwriting)."""
        idx = 0
        path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{idx}")
        if not self.overwrite:
            while os.path.exists(path + ".hdf5"):
                idx += 1
                path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{idx}")
        self.dataset_path = path

    def collect_data(self, ts, action):
        """Append one timestep's observation fields and the commanded action."""
        observation = ts.observation
        for field in ("qpos", "qvel", "effort"):
            self.data_dict[f"/observations/{field}"].append(observation[field])
        self.data_dict["/action"].append(action)
        for cam in self.camera_names:
            self.data_dict[f"/observations/images/{cam}"].append(
                observation["images"][cam]
            )

    def save_data(self):
        """Write the buffered episode to ``<dataset_path>.hdf5``; return True."""
        self.create_file()
        start = time.time()
        with h5py.File(
            self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
        ) as root:
            root.attrs["sim"] = False
            obs = root.create_group("observations")
            image = obs.create_group("images")
            # One 480x640x3 uint8 frame per chunk per camera stream.
            for cam in self.camera_names:
                image.create_dataset(
                    cam,
                    (self.max_timesteps, 480, 640, 3),
                    dtype="uint8",
                    chunks=(1, 480, 640, 3),
                )
            state_shape = (self.max_timesteps, self.state_dim)
            obs.create_dataset("qpos", state_shape)
            obs.create_dataset("qvel", state_shape)
            obs.create_dataset("effort", state_shape)
            root.create_dataset("action", state_shape)
            for name, array in self.data_dict.items():
                root[name][...] = array
        print(f"Saving: {time.time() - start:.1f} secs")
        return True
class DataSynchronizer:
    """ROS node that time-aligns four camera topics and four arm JointState
    topics and records each synchronized tuple through a DataCollector.

    Recording is armed by the aloha-state topic: once both flags read 1 the
    synchronized callbacks are collected until ``max_timesteps`` steps, then
    the episode is saved to HDF5 and the node re-arms.
    """

    def __init__(self, config_path="config"):
        rospy.init_node("synchronizer", anonymous=True)
        rospy.loginfo("ROS node initialized")
        with open(config_path, "r") as file:
            config = yaml.safe_load(file)
        # Index of the gripper joint driven by the master arm (used by the
        # commented-out Inspire-gripper mirroring below).
        self.arm_axis = config["arm_axis"]
        # Create subscribers
        self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
        self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
        self.camera_bottom_sub = Subscriber(
            config["ros_topics"]["camera_bottom"], Image
        )
        self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
        self.left_master_arm_sub = Subscriber(
            config["ros_topics"]["left_master_arm"], JointState
        )
        self.left_slave_arm_sub = Subscriber(
            config["ros_topics"]["left_slave_arm"], JointState
        )
        self.right_master_arm_sub = Subscriber(
            config["ros_topics"]["right_master_arm"], JointState
        )
        self.right_slave_arm_sub = Subscriber(
            config["ros_topics"]["right_slave_arm"], JointState
        )
        # NOTE(review): despite the "_pub" suffix this is a Subscriber.
        self.left_aloha_state_pub = rospy.Subscriber(
            config["ros_topics"]["left_aloha_state"],
            Int32MultiArray,
            self.aloha_state_callback,
        )
        rospy.loginfo("Subscribers created")
        # Create the approximate-time synchronizer over all eight topics.
        self.ats = ApproximateTimeSynchronizer(
            [
                self.camera_left_sub,
                self.camera_right_sub,
                self.camera_bottom_sub,
                self.camera_head_sub,
                self.left_master_arm_sub,
                self.left_slave_arm_sub,
                self.right_master_arm_sub,
                self.right_slave_arm_sub,
            ],
            queue_size=1,
            slop=0.05,  # messages within 50 ms count as simultaneous
        )
        self.ats.registerCallback(self.callback)
        rospy.loginfo("Time synchronizer created and callback registered")
        self.data_collector = DataCollector(
            dataset_dir=config["dataset_dir"],
            dataset_name=config["dataset_name"],
            max_timesteps=config["max_timesteps"],
            camera_names=config["camera_names"],
            state_dim=config["state_dim"],
            overwrite=config["overwrite"],
        )
        self.timesteps_collected = 0
        self.begin_collect = False  # armed by aloha_state_callback
        self.last_time = None  # wall-clock time of the previous callback

    def callback(
        self,
        camera_left_img,
        camera_right_img,
        camera_bottom_img,
        camera_head_img,
        left_master_arm,
        left_slave_arm,
        right_master_arm,
        right_slave_arm,
    ):
        """Handle one time-synchronized bundle of images and joint states."""
        if self.begin_collect:
            self.timesteps_collected += 1
            rospy.loginfo(
                f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
            )
        else:
            self.timesteps_collected = 0
            return
        # NOTE(review): unreachable after the increment above — kept as a guard.
        if self.timesteps_collected == 0:
            return
        current_time = time.time()
        if self.last_time is not None:
            # Report the effective synchronized-callback rate.
            frequency = 1.0 / (current_time - self.last_time)
            rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
        self.last_time = current_time
        # Convert ROS Image messages to (height, width, channels) uint8 arrays.
        camera_left_np_img = np.frombuffer(
            camera_left_img.data, dtype=np.uint8
        ).reshape(camera_left_img.height, camera_left_img.width, -1)
        camera_right_np_img = np.frombuffer(
            camera_right_img.data, dtype=np.uint8
        ).reshape(camera_right_img.height, camera_right_img.width, -1)
        camera_bottom_np_img = np.frombuffer(
            camera_bottom_img.data, dtype=np.uint8
        ).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
        camera_head_np_img = np.frombuffer(
            camera_head_img.data, dtype=np.uint8
        ).reshape(camera_head_img.height, camera_head_img.width, -1)
        # Extract joint angles, velocities and efforts for each arm.
        left_master_arm_angle = left_master_arm.position
        # left_master_arm_velocity = left_master_arm.velocity
        # left_master_arm_force = left_master_arm.effort
        left_slave_arm_angle = left_slave_arm.position
        left_slave_arm_velocity = left_slave_arm.velocity
        left_slave_arm_force = left_slave_arm.effort
        # For Inspire grippers the slave gripper angle mirrors the master arm;
        # keep commented out for other gripper types.
        # left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
        right_master_arm_angle = right_master_arm.position
        # right_master_arm_velocity = right_master_arm.velocity
        # right_master_arm_force = right_master_arm.effort
        right_slave_arm_angle = right_slave_arm.position
        right_slave_arm_velocity = right_slave_arm.velocity
        right_slave_arm_force = right_slave_arm.effort
        # For Inspire grippers the slave gripper angle mirrors the master arm;
        # keep commented out for other gripper types.
        # right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
        # Assemble the observation: slave arms are the state, cameras the images.
        obs = collections.OrderedDict(
            {
                "qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
                "qvel": np.concatenate(
                    [left_slave_arm_velocity, right_slave_arm_velocity]
                ),
                "effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
                "images": {
                    "cam_front": camera_head_np_img,
                    "cam_low": camera_bottom_np_img,
                    "cam_left": camera_left_np_img,
                    "cam_right": camera_right_np_img,
                },
            }
        )
        ts = dm_env.TimeStep(
            step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
        )
        # Action = what the operator commanded: the master-arm joint angles.
        action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
        self.data_collector.collect_data(ts, action)
        # Check whether enough timesteps were collected to end the episode.
        if self.timesteps_collected >= self.data_collector.max_timesteps:
            self.data_collector.save_data()
            rospy.loginfo("Data collection complete")
            self.data_collector.init_dict()
            self.begin_collect = False
            self.timesteps_collected = 0

    def aloha_state_callback(self, data):
        """Arm recording when both aloha-state flags read 1."""
        if not self.begin_collect:
            self.aloha_state = data.data
            print(self.aloha_state[0], self.aloha_state[1])
            if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
                self.begin_collect = True

    def run(self):
        """Block in the ROS event loop until shutdown."""
        rospy.loginfo("Starting ROS spin")
        rospy.spin()
if __name__ == "__main__":
    # Launch the synchronizer with the on-robot YAML configuration.
    config_file = "/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
    DataSynchronizer(config_file).run()

View File

@@ -0,0 +1,31 @@
<?xml version="1.0"?>
<package format="2">
<name>shadow_rm_aloha</name>
<version>0.0.1</version>
<description>The shadow_rm_aloha package</description>
<maintainer email="your_email@example.com">Your Name</maintainer>
<license>TODO</license>
<buildtool_depend>catkin</buildtool_depend>
<build_depend>rospy</build_depend>
<build_depend>sensor_msgs</build_depend>
<build_depend>std_msgs</build_depend>
<build_depend>cv_bridge</build_depend>
<build_depend>image_transport</build_depend>
<build_depend>message_generation</build_depend>
<build_depend>message_runtime</build_depend>
<exec_depend>rospy</exec_depend>
<exec_depend>sensor_msgs</exec_depend>
<exec_depend>std_msgs</exec_depend>
<exec_depend>cv_bridge</exec_depend>
<exec_depend>image_transport</exec_depend>
<exec_depend>message_runtime</exec_depend>
<export>
</export>
</package>

View File

@@ -0,0 +1,5 @@
# GetArmStatus.srv
---
sensor_msgs/JointState joint_status

View File

@@ -0,0 +1,4 @@
# GetImage.srv
---
bool success
sensor_msgs/Image image

View File

@@ -0,0 +1,4 @@
# MoveArm.srv
float32[] joint_angle
---
bool success

View File

@@ -0,0 +1 @@
# Single-source package version string.
__version__ = '0.1.0'

View File

@@ -0,0 +1,49 @@
import multiprocessing as mp
import time
def collect_data(arm_id, cam_id, data_queue, lock):
    """Continuously emit simulated (timestamp, arm, camera) samples.

    Runs forever; intended as a child-process target.
    """
    while True:
        # Simulated sensor payloads.
        arm_data = f"Arm {arm_id} data"
        cam_data = f"Cam {cam_id} data"
        # Wall-clock timestamp for this sample.
        timestamp = time.time()
        # Push the sample onto the shared queue.
        # NOTE(review): mp.Queue is itself process-safe; confirm whether the
        # extra lock is actually required here.
        with lock:
            data_queue.put((timestamp, arm_data, cam_data))
        # Simulate a high frame rate (~100 Hz).
        time.sleep(0.01)
def main():
    """Spawn four producer processes and print their samples as they arrive.

    Runs until interrupted (Ctrl-C), then terminates and joins the children.
    """
    num_arms = 4
    num_cams = 4  # one camera paired with each arm in this demo
    # Shared queue and lock handed to every producer.
    data_queue = mp.Queue()
    lock = mp.Lock()
    # Start one producer process per arm/camera pair.
    processes = []
    for i in range(num_arms):
        p = mp.Process(target=collect_data, args=(i, i, data_queue, lock))
        processes.append(p)
        p.start()
    # Consume in the parent process.
    try:
        while True:
            # Blocking get() replaces the original empty()-polling loop, which
            # busy-waited on a full CPU core and raced with the producers
            # (empty() can change between the check and the get()).
            timestamp, arm_data, cam_data = data_queue.get()
            print(f"Timestamp: {timestamp}, {arm_data}, {cam_data}")
    except KeyboardInterrupt:
        for p in processes:
            p.terminate()
        for p in processes:
            p.join()

View File

@@ -0,0 +1,38 @@
import os
import shutil
from datetime import datetime
from shadow_rm_aloha.data_sub_process.aloha_data_synchronizer import DataCollector
def test_create_dataset_dir():
    """Exercise DataCollector directory creation and file-name auto-increment."""
    # Test parameters.
    dataset_dir = './test_data/dataset'
    dataset_name = 'test_episode'
    max_timesteps = 100
    camera_names = ['cam1', 'cam2']
    state_dim = 14
    overwrite = False
    # Remove stale test data.
    if os.path.exists(dataset_dir):
        shutil.rmtree(dataset_dir)
    # BUG FIX: the original omitted the required `state_dim` positional
    # argument, so `overwrite` was silently consumed as `state_dim`.
    collector = DataCollector(
        dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite
    )
    # The date-stamped directory must exist after construction.
    date_str = datetime.now().strftime("%Y%m%d")
    expected_dir = os.path.join(dataset_dir, date_str)
    assert os.path.exists(expected_dir), f"Expected directory {expected_dir} does not exist."
    # `dataset_path` is only assigned by create_file(); the original asserted
    # on it right after construction, which raises AttributeError.
    collector.create_file()
    expected_file = os.path.join(expected_dir, f"{dataset_name}_0")
    assert collector.dataset_path == expected_file, f"Expected file path {expected_file}, but got {collector.dataset_path}"
    # Simulate a saved episode, then check that the name suffix increments.
    open(expected_file + ".hdf5", "w").close()
    collector.create_file()
    expected_file_incremented = os.path.join(expected_dir, f"{dataset_name}_1")
    assert collector.dataset_path == expected_file_incremented, f"Expected file path {expected_file_incremented}, but got {collector.dataset_path}"
    shutil.rmtree(dataset_dir)
    print("All tests passed.")


if __name__ == '__main__':
    test_create_dataset_dir()

View File

@@ -0,0 +1,105 @@
import multiprocessing
import time
import random
import socket
import json
import logging
# 设置日志级别
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
class test_udp():
    """Configure a Realman arm's realtime UDP state push and dump packets.

    The TCP command socket (port 8080) enables the controller's realtime
    push towards this host; the UDP socket (port 8090) then receives it.
    """

    def __init__(self):
        arm_ip = '192.168.1.19'
        arm_port = 8080
        self.arm = socket.socket()
        self.arm.connect((arm_ip, arm_port))
        # Ask the controller to push its realtime state to this host via UDP.
        set_udp = {"command":"set_realtime_push","cycle":1,"enable":True,"port":8090,"ip":"192.168.1.101","custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}}
        self.arm.send(json.dumps(set_udp).encode('utf-8'))
        state = self.arm.recv(1024)
        logging.info(f"Send data to {arm_ip}:{arm_port}: {state}")
        self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Allow quick restarts on the same local port.
        self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        local_ip = "192.168.1.101"
        local_port = 8090
        self.udp_socket.bind((local_ip, local_port))
        self.BUFFER_SIZE = 1024

    def set_udp(self):
        """Receive-and-print loop for the realtime UDP stream (runs forever)."""
        try:
            while True:
                data, addr = self.udp_socket.recvfrom(self.BUFFER_SIZE)
                # Decode the UDP payload as JSON.
                data = json.loads(data.decode('utf-8'))
                print(f"Received data {data}")
        finally:
            # BUG FIX: the original called `udp_socket.close()` (an undefined
            # name) as dead code after the loop; close the instance socket on
            # the way out instead.
            self.udp_socket.close()
def collect_arm_data(arm_id, queue, event):
    """Emit one simulated arm sample per second and signal *event* each time."""
    while True:
        sample = f"Arm {arm_id} data {random.random()}"
        queue.put((arm_id, sample))
        event.set()
        time.sleep(1)
def collect_camera_data(camera_id, queue, event):
    """Emit one simulated camera sample per second and signal *event* each time."""
    while True:
        sample = f"Camera {camera_id} data {random.random()}"
        queue.put((camera_id, sample))
        event.set()
        time.sleep(1)
def main():
    """Fan out four arm and four camera producer processes; the parent polls
    their queues, using one Event per producer as a data-ready flag."""
    arm_queues = [multiprocessing.Queue() for _ in range(4)]
    camera_queues = [multiprocessing.Queue() for _ in range(4)]
    arm_events = [multiprocessing.Event() for _ in range(4)]
    camera_events = [multiprocessing.Event() for _ in range(4)]
    arm_processes = [multiprocessing.Process(target=collect_arm_data, args=(i, arm_queues[i], arm_events[i])) for i in range(4)]
    camera_processes = [multiprocessing.Process(target=collect_camera_data, args=(i, camera_queues[i], camera_events[i])) for i in range(4)]
    for p in arm_processes + camera_processes:
        p.start()
    try:
        while True:
            # Wait until every producer has signalled at least once.
            # NOTE(review): waiting on ALL events serially means one stalled
            # producer blocks consumption from the others — confirm intended.
            for event in arm_events + camera_events:
                event.wait()
            for i in range(4):
                if not arm_queues[i].empty():
                    arm_id, arm_data = arm_queues[i].get()
                    print(f"Received from Arm {arm_id}: {arm_data}")
                    arm_events[i].clear()
                if not camera_queues[i].empty():
                    camera_id, camera_data = camera_queues[i].get()
                    print(f"Received from Camera {camera_id}: {camera_data}")
                    camera_events[i].clear()
            time.sleep(0.1)
    except KeyboardInterrupt:
        # Best-effort shutdown of all children on Ctrl-C.
        for p in arm_processes + camera_processes:
            p.terminate()


if __name__ == "__main__":
    main()
# if __name__ == "__main__":
# test_udp = test_udp()
# test_udp.set_udp()