forked from tangger/lerobot
add realman shadow src
This commit is contained in:
10
realman_src/realman_aloha/shadow_rm_aloha/.gitignore
vendored
Normal file
10
realman_src/realman_aloha/shadow_rm_aloha/.gitignore
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
__pycache__/
|
||||
build/
|
||||
devel/
|
||||
dist/
|
||||
data/
|
||||
.catkin_workspace
|
||||
*.pyc
|
||||
*.pyo
|
||||
*.pt
|
||||
.vscode/
|
||||
3
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore
generated
vendored
Normal file
3
realman_src/realman_aloha/shadow_rm_aloha/.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
1
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name
generated
Normal file
1
realman_src/realman_aloha/shadow_rm_aloha/.idea/.name
generated
Normal file
@@ -0,0 +1 @@
|
||||
aloha_data_synchronizer.py
|
||||
17
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
17
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,17 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="4">
|
||||
<item index="0" class="java.lang.String" itemvalue="tensorboard" />
|
||||
<item index="1" class="java.lang.String" itemvalue="thop" />
|
||||
<item index="2" class="java.lang.String" itemvalue="torch" />
|
||||
<item index="3" class="java.lang.String" itemvalue="torchvision" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
6
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
realman_src/realman_aloha/shadow_rm_aloha/.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml
generated
Normal file
7
realman_src/realman_aloha/shadow_rm_aloha/.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.11 (随箱软件)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.11 (随箱软件)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml
generated
Normal file
8
realman_src/realman_aloha/shadow_rm_aloha/.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" filepath="$PROJECT_DIR$/.idea/shadow_rm_aloha.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
12
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml
generated
Normal file
12
realman_src/realman_aloha/shadow_rm_aloha/.idea/shadow_rm_aloha.iml
generated
Normal file
@@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
0
realman_src/realman_aloha/shadow_rm_aloha/README.md
Normal file
0
realman_src/realman_aloha/shadow_rm_aloha/README.md
Normal file
@@ -0,0 +1,38 @@
|
||||
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset'
|
||||
dataset_name: 'episode'
|
||||
max_timesteps: 500
|
||||
state_dim: 14
|
||||
overwrite: False
|
||||
arm_axis: 6
|
||||
camera_names:
|
||||
- 'cam_high'
|
||||
- 'cam_low'
|
||||
- 'cam_left'
|
||||
- 'cam_right'
|
||||
ros_topics:
|
||||
camera_left: '/camera_left/rgb/image_raw'
|
||||
camera_right: '/camera_right/rgb/image_raw'
|
||||
camera_bottom: '/camera_bottom/rgb/image_raw'
|
||||
camera_head: '/camera_head/rgb/image_raw'
|
||||
left_master_arm: '/left_master_arm_joint_states'
|
||||
left_slave_arm: '/left_slave_arm_joint_states'
|
||||
right_master_arm: '/right_master_arm_joint_states'
|
||||
right_slave_arm: '/right_slave_arm_joint_states'
|
||||
left_aloha_state: '/left_slave_arm_aloha_state'
|
||||
right_aloha_state: '/right_slave_arm_aloha_state'
|
||||
robot_env: {
|
||||
# TODO change the path to the correct one
|
||||
rm_left_arm: '/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml',
|
||||
rm_right_arm: '/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml',
|
||||
arm_axis: 6,
|
||||
head_camera: '241122071186',
|
||||
bottom_camera: '152122078546',
|
||||
left_camera: '150622070125',
|
||||
right_camera: '151222072576',
|
||||
init_left_arm_angle: [7.235, 31.816, 51.237, 2.463, 91.054, 12.04, 0.0],
|
||||
init_right_arm_angle: [-6.155, 33.925, 62.137, -1.672, 87.892, -3.868, 0.0]
|
||||
# init_left_arm_angle: [6.681, 38.496, 66.093, -1.141, 74.529, 3.076, 0.0],
|
||||
# init_right_arm_angle: [-4.79, 37.062, 72.393, -0.477, 68.593, -9.526, 0.0]
|
||||
# init_left_arm_angle: [6.45, 66.093, 2.9, 20.919, -1.491, 100.756, 18.808, 0.617],
|
||||
# init_right_arm_angle: [166.953, -33.575, -163.917, 73.3, -9.581, 69.51, 0.876]
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
arm_ip: "192.168.1.18"
|
||||
arm_port: 8080
|
||||
arm_axis: 6
|
||||
local_ip: "192.168.1.101"
|
||||
local_port: 8089
|
||||
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||
get_vel: True
|
||||
get_torque: True
|
||||
@@ -0,0 +1,9 @@
|
||||
arm_ip: "192.168.1.19"
|
||||
arm_port: 8080
|
||||
arm_axis: 6
|
||||
local_ip: "192.168.1.101"
|
||||
local_port: 8090
|
||||
# arm_ki: [7, 7, 7, 3, 3, 3, 3] # rm75
|
||||
arm_ki: [7, 7, 7, 3, 3, 3] # rm65
|
||||
get_vel: True
|
||||
get_torque: True
|
||||
@@ -0,0 +1,4 @@
|
||||
port: /dev/ttyUSB1
|
||||
baudrate: 460800
|
||||
hex_data: "55 AA 02 00 00 67"
|
||||
arm_axis: 6
|
||||
@@ -0,0 +1,4 @@
|
||||
port: /dev/ttyUSB0
|
||||
baudrate: 460800
|
||||
hex_data: "55 AA 02 00 00 67"
|
||||
arm_axis: 6
|
||||
@@ -0,0 +1,6 @@
|
||||
dataset_dir: '/home/rm/code/shadow_rm_aloha/data/dataset/20250102'
|
||||
dataset_name: 'episode'
|
||||
episode_idx: 1
|
||||
FPS: 30
|
||||
# joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate", "J7"] # 7 joints
|
||||
joint_names: ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"] # 6 joints
|
||||
39
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml
Normal file
39
realman_src/realman_aloha/shadow_rm_aloha/pyproject.toml
Normal file
@@ -0,0 +1,39 @@
|
||||
[tool.poetry]
|
||||
name = "shadow_rm_aloha"
|
||||
version = "0.1.1"
|
||||
description = "aloha package, use D435 and Realman robot arm to build aloha to collect data"
|
||||
readme = "README.md"
|
||||
authors = ["Shadow <qiuchengzhan@gmail.com>"]
|
||||
license = "MIT"
|
||||
# include = ["realman_vision/pytransform/_pytransform.so",]
|
||||
classifiers = [
|
||||
"Operating System :: POSIX :: Linux amd64",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.10"
|
||||
matplotlib = ">=3.9.2"
|
||||
h5py = ">=3.12.1"
|
||||
# rospy = ">=1.17.0"
|
||||
# shadow_rm_robot = { git = "https://github.com/Shadow2223/shadow_rm_robot.git", branch = "main" }
|
||||
# shadow_camera = { git = "https://github.com/Shadow2223/shadow_camera.git", branch = "main" }
|
||||
|
||||
|
||||
|
||||
[tool.poetry.dev-dependencies] # 列出开发时所需的依赖项,比如测试、文档生成等工具。
|
||||
pytest = ">=8.3"
|
||||
black = ">=24.10.0"
|
||||
|
||||
|
||||
|
||||
[tool.poetry.plugins."scripts"] # 定义命令行脚本,使得用户可以通过命令行运行指定的函数。
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.8.4"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
@@ -0,0 +1,42 @@
|
||||
cmake_minimum_required(VERSION 3.0.2)
|
||||
project(shadow_rm_aloha)
|
||||
|
||||
find_package(catkin REQUIRED COMPONENTS
|
||||
rospy
|
||||
sensor_msgs
|
||||
cv_bridge
|
||||
image_transport
|
||||
std_msgs
|
||||
message_generation
|
||||
)
|
||||
|
||||
add_service_files(
|
||||
FILES
|
||||
GetArmStatus.srv
|
||||
GetImage.srv
|
||||
MoveArm.srv
|
||||
)
|
||||
|
||||
generate_messages(
|
||||
DEPENDENCIES
|
||||
sensor_msgs
|
||||
std_msgs
|
||||
)
|
||||
|
||||
catkin_package(
|
||||
CATKIN_DEPENDS message_runtime rospy std_msgs
|
||||
)
|
||||
|
||||
include_directories(
|
||||
${catkin_INCLUDE_DIRS}
|
||||
)
|
||||
|
||||
install(PROGRAMS
|
||||
arm_node/slave_arm_publisher.py
|
||||
arm_node/master_arm_publisher.py
|
||||
arm_node/slave_arm_pub_sub.py
|
||||
camera_node/camera_publisher.py
|
||||
data_sub_process/aloha_data_synchronizer.py
|
||||
data_sub_process/aloha_data_collect.py
|
||||
DESTINATION ${CATKIN_PACKAGE_BIN_DESTINATION}
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
import logging
|
||||
from shadow_rm_robot.servo_robotic_arm import ServoArm
|
||||
from sensor_msgs.msg import JointState
|
||||
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
class MasterArmPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node("master_arm_publisher", anonymous=True)
|
||||
arm_config = rospy.get_param("~arm_config","config/servo_left_arm.yaml")
|
||||
hz = rospy.get_param("~hz", 250)
|
||||
self.joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||
|
||||
self.arm = ServoArm(arm_config)
|
||||
self.publisher = rospy.Publisher(self.joint_states_topic, JointState, queue_size=1)
|
||||
self.rate = rospy.Rate(hz) # 30 Hz
|
||||
|
||||
def publish_joint_states(self):
|
||||
while not rospy.is_shutdown():
|
||||
joint_state = JointState()
|
||||
joint_pos = self.arm.get_joint_actions()
|
||||
|
||||
joint_state.header.stamp = rospy.Time.now()
|
||||
joint_state.name = list(joint_pos.keys())
|
||||
joint_state.position = list(joint_pos.values())
|
||||
joint_state.velocity = [0.0] * len(joint_pos) # 速度(可根据需要修改)
|
||||
joint_state.effort = [0.0] * len(joint_pos) # 力矩(可根据需要修改)
|
||||
|
||||
# rospy.loginfo(f"{self.joint_states_topic}: {joint_state}")
|
||||
self.publisher.publish(joint_state)
|
||||
self.rate.sleep()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
arm_publisher = MasterArmPublisher()
|
||||
arm_publisher.publish_joint_states()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
import logging
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from sensor_msgs.msg import JointState
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
class SlaveArmPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node("slave_arm_publisher", anonymous=True)
|
||||
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||
hz = rospy.get_param("~hz", 250)
|
||||
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||
joint_actions_topic = rospy.get_param("~joint_actions_topic", "/joint_actions")
|
||||
self.arm = RmArm(arm_config)
|
||||
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
|
||||
self.subscriber = rospy.Subscriber(joint_actions_topic, JointState, self.callback)
|
||||
self.rate = rospy.Rate(hz)
|
||||
|
||||
def publish_joint_states(self):
|
||||
while not rospy.is_shutdown():
|
||||
joint_state = JointState()
|
||||
data = self.arm.get_integrate_data()
|
||||
# data = self.arm.get_arm_data()
|
||||
joint_state.header.stamp = rospy.Time.now()
|
||||
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6"]
|
||||
# joint_state.position = data["joint_angle"]
|
||||
joint_state.position = data['arm_angle']
|
||||
|
||||
# joint_state.position = list(data["arm_angle"])
|
||||
# joint_state.velocity = list(data["arm_velocity"])
|
||||
# joint_state.effort = list(data["arm_torque"])
|
||||
# rospy.loginfo(f"joint_states_topic: {joint_state}")
|
||||
self.publisher.publish(joint_state)
|
||||
self.rate.sleep()
|
||||
|
||||
def callback(self, data):
|
||||
# rospy.loginfo(f"Received joint_states_topic: {data}")
|
||||
start_time = rospy.Time.now()
|
||||
if data is None:
|
||||
return
|
||||
if data.name == ["joint_canfd"]:
|
||||
self.arm.set_joint_canfd_position(data.position[0:6])
|
||||
elif data.name == ["joint_j"]:
|
||||
self.arm.set_joint_position(data.position[0:6])
|
||||
|
||||
# self.arm.set_gripper_position(data.position[6])
|
||||
end_time = rospy.Time.now()
|
||||
time_cost_ms = (end_time - start_time).to_sec() * 1000
|
||||
rospy.loginfo(f"Time cost: {data.name},{time_cost_ms}")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
arm_publisher = SlaveArmPublisher()
|
||||
arm_publisher.publish_joint_states()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,44 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
import logging
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from sensor_msgs.msg import JointState
|
||||
from std_msgs.msg import Int32MultiArray
|
||||
# 配置日志记录
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
class SlaveArmPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node("slave_arm_publisher", anonymous=True)
|
||||
arm_config = rospy.get_param("~arm_config", default="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml")
|
||||
hz = rospy.get_param("~hz", 250)
|
||||
joint_states_topic = rospy.get_param("~joint_states_topic", "/joint_states")
|
||||
aloha_state_topic = rospy.get_param("~aloha_state_topic", "/aloha_state")
|
||||
self.arm = RmArm(arm_config)
|
||||
self.publisher = rospy.Publisher(joint_states_topic, JointState, queue_size=1)
|
||||
self.aloha_state_pub = rospy.Publisher(aloha_state_topic, Int32MultiArray, queue_size=1)
|
||||
self.rate = rospy.Rate(hz)
|
||||
|
||||
def publish_joint_states(self):
|
||||
while not rospy.is_shutdown():
|
||||
joint_state = JointState()
|
||||
data = self.arm.get_integrate_data()
|
||||
joint_state.header.stamp = rospy.Time.now()
|
||||
joint_state.name = ["joint1", "joint2", "joint3", "joint4", "joint5", "joint6", "joint7"]
|
||||
joint_state.position = list(data["arm_angle"])
|
||||
joint_state.velocity = list(data["arm_velocity"])
|
||||
joint_state.effort = list(data["arm_torque"])
|
||||
# rospy.loginfo(f"joint_states_topic: {joint_state}")
|
||||
self.aloha_state_pub.publish(Int32MultiArray(data=data["aloha_state"].values()))
|
||||
self.publisher.publish(joint_state)
|
||||
self.rate.sleep()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
arm_publisher = SlaveArmPublisher()
|
||||
arm_publisher.publish_joint_states()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,69 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import rospy
|
||||
from sensor_msgs.msg import Image
|
||||
from std_msgs.msg import Header
|
||||
import numpy as np
|
||||
from shadow_camera.realsense import RealSenseCamera
|
||||
|
||||
class CameraPublisher:
|
||||
def __init__(self):
|
||||
rospy.init_node('camera_publisher', anonymous=True)
|
||||
self.serial_number = rospy.get_param('~serial_number', None)
|
||||
hz = rospy.get_param('~hz', 30)
|
||||
rospy.loginfo(f"Serial number: {self.serial_number}")
|
||||
|
||||
self.rgb_topic = rospy.get_param('~rgb_topic', '/camera/rgb/image_raw')
|
||||
self.depth_topic = rospy.get_param('~depth_topic', '/camera/depth/image_raw')
|
||||
rospy.loginfo(f"RGB topic: {self.rgb_topic}")
|
||||
rospy.loginfo(f"Depth topic: {self.depth_topic}")
|
||||
|
||||
self.rgb_pub = rospy.Publisher(self.rgb_topic, Image, queue_size=10)
|
||||
# self.depth_pub = rospy.Publisher(self.depth_topic, Image, queue_size=10)
|
||||
self.rate = rospy.Rate(hz) # 30 Hz
|
||||
self.camera = RealSenseCamera(self.serial_number, False)
|
||||
|
||||
rospy.loginfo("Camera initialized")
|
||||
|
||||
def publish_images(self):
|
||||
self.camera.start_camera()
|
||||
rospy.loginfo("Camera started")
|
||||
while not rospy.is_shutdown():
|
||||
result = self.camera.read_frame(True, False, False, False)
|
||||
if result is None:
|
||||
rospy.logerr("Failed to read frame from camera")
|
||||
continue
|
||||
|
||||
color_image, depth_image, _, _ = result
|
||||
|
||||
if color_image is not None or depth_image is not None:
|
||||
rgb_msg = self.create_image_msg(color_image, "bgr8")
|
||||
# depth_msg = self.create_image_msg(depth_image, "mono16")
|
||||
|
||||
self.rgb_pub.publish(rgb_msg)
|
||||
# self.depth_pub.publish(depth_msg)
|
||||
# rospy.loginfo("Published RGB image")
|
||||
else:
|
||||
rospy.logwarn("Received None for color_image or depth_image")
|
||||
|
||||
self.rate.sleep()
|
||||
self.camera.stop_camera()
|
||||
rospy.loginfo("Camera stopped")
|
||||
|
||||
def create_image_msg(self, image, encoding):
|
||||
msg = Image()
|
||||
msg.header = Header()
|
||||
msg.header.stamp = rospy.Time.now()
|
||||
msg.height, msg.width = image.shape[:2]
|
||||
msg.encoding = encoding
|
||||
msg.is_bigendian = False
|
||||
msg.step = image.strides[0]
|
||||
msg.data = np.array(image).tobytes()
|
||||
return msg
|
||||
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
camera_publisher = CameraPublisher()
|
||||
camera_publisher.publish_images()
|
||||
except rospy.ROSInterruptException:
|
||||
pass
|
||||
@@ -0,0 +1,112 @@
|
||||
import os
|
||||
import time
|
||||
import h5py
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
class DataCollector:
|
||||
def __init__(self, dataset_dir, dataset_name, max_timesteps, camera_names, state_dim, overwrite=False):
|
||||
self.arm_axis = 7
|
||||
self.dataset_dir = dataset_dir
|
||||
self.dataset_name = dataset_name
|
||||
self.max_timesteps = max_timesteps
|
||||
self.camera_names = camera_names
|
||||
self.state_dim = state_dim
|
||||
self.overwrite = overwrite
|
||||
self.data_dict = {
|
||||
'/observations/qpos': [],
|
||||
'/observations/qvel': [],
|
||||
'/observations/effort': [],
|
||||
'/action': [],
|
||||
}
|
||||
for cam_name in camera_names:
|
||||
self.data_dict[f'/observations/images/{cam_name}'] = []
|
||||
|
||||
# 自动检测和创建数据集目录
|
||||
self.create_dataset_dir()
|
||||
self.timesteps_collected = 0
|
||||
|
||||
|
||||
def create_dataset_dir(self):
|
||||
# 按照年月日创建目录
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||
if not os.path.exists(self.dataset_dir):
|
||||
os.makedirs(self.dataset_dir)
|
||||
|
||||
def collect_data(self, ts, action):
|
||||
self.data_dict['/observations/qpos'].append(ts.observation['qpos'])
|
||||
self.data_dict['/observations/qvel'].append(ts.observation['qvel'])
|
||||
self.data_dict['/observations/effort'].append(ts.observation['effort'])
|
||||
self.data_dict['/action'].append(action)
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
|
||||
|
||||
def save_data(self):
|
||||
t0 = time.time()
|
||||
# 保存数据
|
||||
with h5py.File(self.dataset_path, mode='w', rdcc_nbytes=1024**2*2) as root:
|
||||
root.attrs['sim'] = False
|
||||
obs = root.create_group('observations')
|
||||
image = obs.create_group('images')
|
||||
for cam_name in self.camera_names:
|
||||
_ = image.create_dataset(cam_name, (self.max_timesteps, 480, 640, 3), dtype='uint8',
|
||||
chunks=(1, 480, 640, 3))
|
||||
_ = obs.create_dataset('qpos', (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset('qvel', (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset('effort', (self.max_timesteps, self.state_dim))
|
||||
_ = root.create_dataset('action', (self.max_timesteps, self.state_dim))
|
||||
|
||||
for name, array in self.data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f'Saving: {time.time() - t0:.1f} secs')
|
||||
return True
|
||||
|
||||
def load_hdf5(self, orign_path, file):
|
||||
self.dataset_path = os.path.join(self.dataset_dir, file)
|
||||
if not os.path.isfile(orign_path):
|
||||
logging.error(f'Dataset does not exist at {orign_path}')
|
||||
exit()
|
||||
|
||||
with h5py.File(orign_path, 'r') as root:
|
||||
self.is_sim = root.attrs['sim']
|
||||
self.qpos = root['/observations/qpos'][()]
|
||||
self.qvel = root['/observations/qvel'][()]
|
||||
self.effort = root['/observations/effort'][()]
|
||||
self.action = root['/action'][()]
|
||||
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||
for cam_name in root[f'/observations/images/'].keys()}
|
||||
|
||||
self.qpos[:, self.arm_axis] = self.action[:, self.arm_axis]
|
||||
self.qpos[:, self.arm_axis*2+1] = self.action[:, self.arm_axis*2+1]
|
||||
|
||||
self.data_dict['/observations/qpos'] = self.qpos
|
||||
self.data_dict['/observations/qvel'] = self.qvel
|
||||
self.data_dict['/observations/effort'] = self.effort
|
||||
self.data_dict['/action'] = self.action
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f'/observations/images/{cam_name}'] = self.image_dict[cam_name]
|
||||
return True
|
||||
|
||||
if __name__ == '__main__':
|
||||
"""
|
||||
用于更改夹爪数据,将从臂夹爪数据更改为主笔夹爪数据
|
||||
|
||||
"""
|
||||
dataset_dir = '/home/wang/project/shadow_rm_aloha/data'
|
||||
orign_dir = '/home/wang/project/shadow_rm_aloha/data/dataset/20241128'
|
||||
dataset_name = 'test'
|
||||
max_timesteps = 300
|
||||
camera_names = ['cam_high','cam_low','cam_left','cam_right']
|
||||
state_dim = 16
|
||||
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
|
||||
for file in os.listdir(orign_dir):
|
||||
collector.__init__(dataset_dir, dataset_name, max_timesteps, camera_names, state_dim)
|
||||
orign_path = os.path.join(orign_dir, file)
|
||||
print(orign_path)
|
||||
collector.load_hdf5(orign_path, file)
|
||||
collector.save_data()
|
||||
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import h5py
|
||||
import yaml
|
||||
import logging
|
||||
import time
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
class DataValidator:
|
||||
def __init__(self, config):
|
||||
self.dataset_dir = config['dataset_dir']
|
||||
self.episode_idx = config['episode_idx']
|
||||
self.joint_names = config['joint_names']
|
||||
self.dataset_name = f'episode_{self.episode_idx}'
|
||||
self.dataset_path = os.path.join(self.dataset_dir, self.dataset_name + '.hdf5')
|
||||
self.state_names = self.joint_names + ["gripper"]
|
||||
self.arm = RmArm('/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml')
|
||||
|
||||
def load_hdf5(self):
|
||||
if not os.path.isfile(self.dataset_path):
|
||||
logging.error(f'Dataset does not exist at {self.dataset_path}')
|
||||
exit()
|
||||
|
||||
with h5py.File(self.dataset_path, 'r') as root:
|
||||
self.is_sim = root.attrs['sim']
|
||||
self.qpos = root['/observations/qpos'][()]
|
||||
# self.qvel = root['/observations/qvel'][()]
|
||||
# self.effort = root['/observations/effort'][()]
|
||||
self.action = root['/action'][()]
|
||||
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||
for cam_name in root[f'/observations/images/'].keys()}
|
||||
|
||||
def validate_data(self):
|
||||
# 验证位置数据
|
||||
if not self.qpos.shape[1] == 14:
|
||||
logging.error('qpos shape does not match expected number of joints')
|
||||
return False
|
||||
|
||||
logging.info('Data validation passed')
|
||||
return True
|
||||
|
||||
def control_robot(self):
|
||||
self.arm.set_joint_position(self.qpos[0][0:6])
|
||||
for qpos in self.qpos:
|
||||
logging.info(f'qpos: {qpos}')
|
||||
self.arm.set_joint_canfd_position(qpos[7:13])
|
||||
self.arm.set_gripper_position(qpos[13])
|
||||
time.sleep(0.035)
|
||||
|
||||
|
||||
def run(self):
|
||||
self.load_hdf5()
|
||||
if self.validate_data():
|
||||
self.control_robot()
|
||||
|
||||
def load_config(config_path):
|
||||
with open(config_path, 'r') as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
|
||||
validator = DataValidator(config)
|
||||
validator.run()
|
||||
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import h5py
|
||||
import yaml
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
class DataVisualizer:
|
||||
def __init__(self, config):
|
||||
self.dataset_dir = config['dataset_dir']
|
||||
self.episode_idx = config['episode_idx']
|
||||
self.dt = 1/config['FPS']
|
||||
self.joint_names = config['joint_names']
|
||||
self.state_names = self.joint_names + ["gripper"]
|
||||
# self.camera_names = config['camera_names']
|
||||
|
||||
def join_file_path(self, file_name):
|
||||
self.dataset_path = os.path.join(self.dataset_dir, file_name)
|
||||
|
||||
def load_hdf5(self):
|
||||
if not os.path.isfile(self.dataset_path):
|
||||
logging.error(f'Dataset does not exist at {self.dataset_path}')
|
||||
exit()
|
||||
|
||||
with h5py.File(self.dataset_path, 'r') as root:
|
||||
self.is_sim = root.attrs['sim']
|
||||
self.qpos = root['/observations/qpos'][()]
|
||||
# self.qvel = root['/observations/qvel'][()]
|
||||
# self.effort = root['/observations/effort'][()]
|
||||
self.action = root['/action'][()]
|
||||
self.image_dict = {cam_name: root[f'/observations/images/{cam_name}'][()]
|
||||
for cam_name in root[f'/observations/images/'].keys()}
|
||||
|
||||
def save_videos(self, video, dt, video_path=None):
|
||||
if isinstance(video, list):
|
||||
cam_names = list(video[0].keys())
|
||||
h, w, _ = video[0][cam_names[0]].shape
|
||||
w = w * len(cam_names)
|
||||
fps = int(1 / dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for image_dict in video:
|
||||
images = [image_dict[cam_name][:, :, [2, 1, 0]] for cam_name in cam_names]
|
||||
out.write(np.concatenate(images, axis=1))
|
||||
out.release()
|
||||
logging.info(f'Saved video to: {video_path}')
|
||||
elif isinstance(video, dict):
|
||||
cam_names = list(video.keys())
|
||||
all_cam_videos = np.concatenate([video[cam_name] for cam_name in cam_names], axis=2)
|
||||
n_frames, h, w, _ = all_cam_videos.shape
|
||||
fps = int(1 / dt)
|
||||
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
|
||||
for t in range(n_frames):
|
||||
out.write(all_cam_videos[t][:, :, [2, 1, 0]])
|
||||
out.release()
|
||||
logging.info(f'Saved video to: {video_path}')
|
||||
|
||||
def visualize_joints(self, qpos_list, command_list, plot_path, ylim=None, label_overwrite=None):
|
||||
label1, label2 = label_overwrite if label_overwrite else ('State', 'Command')
|
||||
qpos = np.array(qpos_list)
|
||||
command = np.array(command_list)
|
||||
num_ts, num_dim = qpos.shape
|
||||
logging.info(f'qpos shape: {qpos.shape}, command shape: {command.shape}')
|
||||
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
|
||||
|
||||
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(qpos[:, dim_idx], label=label1)
|
||||
ax.plot(command[:, dim_idx], label=label2)
|
||||
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||
ax.legend()
|
||||
if ylim:
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
logging.info(f'Saved qpos plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def visualize_single(self, data_list, label, plot_path, ylim=None):
|
||||
data = np.array(data_list)
|
||||
num_ts, num_dim = data.shape
|
||||
fig, axs = plt.subplots(num_dim, 1, figsize=(num_dim, 2 * num_dim))
|
||||
|
||||
all_names = [name + '_left' for name in self.state_names] + [name + '_right' for name in self.state_names]
|
||||
for dim_idx in range(num_dim):
|
||||
ax = axs[dim_idx]
|
||||
ax.plot(data[:, dim_idx], label=label)
|
||||
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
|
||||
ax.legend()
|
||||
if ylim:
|
||||
ax.set_ylim(ylim)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
logging.info(f'Saved {label} plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def visualize_timestamp(self, t_list):
|
||||
plot_path = self.dataset_path.replace('.hdf5', '_timestamp.png')
|
||||
fig, axs = plt.subplots(2, 1, figsize=(10, 8))
|
||||
t_float = np.array([secs + nsecs * 1e-9 for secs, nsecs in t_list])
|
||||
|
||||
axs[0].plot(np.arange(len(t_float)), t_float)
|
||||
axs[0].set_title('Camera frame timestamps')
|
||||
axs[0].set_xlabel('timestep')
|
||||
axs[0].set_ylabel('time (sec)')
|
||||
|
||||
axs[1].plot(np.arange(len(t_float) - 1), t_float[:-1] - t_float[1:])
|
||||
axs[1].set_title('dt')
|
||||
axs[1].set_xlabel('timestep')
|
||||
axs[1].set_ylabel('time (sec)')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(plot_path)
|
||||
logging.info(f'Saved timestamp plot to: {plot_path}')
|
||||
plt.close()
|
||||
|
||||
def run(self):
|
||||
for file_name in os.listdir(self.dataset_dir):
|
||||
if file_name.endswith('.hdf5'):
|
||||
self.join_file_path(file_name)
|
||||
self.load_hdf5()
|
||||
video_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_video.mp4'))
|
||||
self.save_videos(self.image_dict, self.dt, video_path)
|
||||
qpos_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_qpos.png'))
|
||||
self.visualize_joints(self.qpos, self.action, qpos_plot_path)
|
||||
# effort_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_effort.png'))
|
||||
# self.visualize_single(self.effort, 'effort', effort_plot_path)
|
||||
# error_plot_path = os.path.join(self.dataset_dir, file_name.replace('.hdf5', '_error.png'))
|
||||
# self.visualize_single(self.action - self.qpos, 'tracking_error', error_plot_path)
|
||||
# self.visualize_timestamp(t_list) # TODO: Add timestamp visualization back
|
||||
|
||||
|
||||
def load_config(config_path):
|
||||
with open(config_path, 'r') as file:
|
||||
return yaml.safe_load(file)
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = load_config('/home/rm/code/shadow_rm_aloha/config/vis_data_path.yaml')
|
||||
visualizer = DataVisualizer(config)
|
||||
visualizer.run()
|
||||
@@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python3
|
||||
import time
|
||||
import yaml
|
||||
import rospy
|
||||
import dm_env
|
||||
import numpy as np
|
||||
import collections
|
||||
from datetime import datetime
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from shadow_rm_robot.realman_arm import RmArm
|
||||
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||
|
||||
|
||||
class DataSynchronizer:
|
||||
def __init__(self, config_path="config"):
|
||||
rospy.init_node("synchronizer", anonymous=True)
|
||||
|
||||
with open(config_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
self.init_left_arm_angle = config["robot_env"]["init_left_arm_angle"]
|
||||
self.init_right_arm_angle = config["robot_env"]["init_right_arm_angle"]
|
||||
self.arm_axis = config["robot_env"]["arm_axis"]
|
||||
self.camera_names = config["camera_names"]
|
||||
|
||||
# 创建订阅者
|
||||
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||
self.camera_bottom_sub = Subscriber(
|
||||
config["ros_topics"]["camera_bottom"], Image
|
||||
)
|
||||
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||
|
||||
self.left_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_slave_arm_sub"], JointState
|
||||
)
|
||||
self.right_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_slave_arm_sub"], JointState
|
||||
)
|
||||
self.left_slave_arm_pub = rospy.Publisher(
|
||||
config["ros_topics"]["left_slave_arm_pub"], JointState, queue_size=1
|
||||
)
|
||||
self.right_slave_arm_pub = rospy.Publisher(
|
||||
config["ros_topics"]["right_slave_arm_pub"], JointState, queue_size=1
|
||||
)
|
||||
|
||||
# 创建同步器
|
||||
self.ats = ApproximateTimeSynchronizer(
|
||||
[
|
||||
self.camera_left_sub,
|
||||
self.camera_right_sub,
|
||||
self.camera_bottom_sub,
|
||||
self.camera_head_sub,
|
||||
self.left_slave_arm_sub,
|
||||
self.right_slave_arm_sub,
|
||||
],
|
||||
queue_size=1,
|
||||
slop=0.1,
|
||||
)
|
||||
|
||||
self.ats.registerCallback(self.callback)
|
||||
self.ts = None
|
||||
self.is_frist_step = True
|
||||
|
||||
def callback(
|
||||
self,
|
||||
camera_left_img,
|
||||
camera_right_img,
|
||||
camera_bottom_img,
|
||||
camera_head_img,
|
||||
left_slave_arm,
|
||||
right_slave_arm,
|
||||
):
|
||||
|
||||
# 将ROS图像消息转换为NumPy数组
|
||||
camera_left_np_img = np.frombuffer(
|
||||
camera_left_img.data, dtype=np.uint8
|
||||
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||
camera_right_np_img = np.frombuffer(
|
||||
camera_right_img.data, dtype=np.uint8
|
||||
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||
camera_bottom_np_img = np.frombuffer(
|
||||
camera_bottom_img.data, dtype=np.uint8
|
||||
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||
camera_head_np_img = np.frombuffer(
|
||||
camera_head_img.data, dtype=np.uint8
|
||||
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||
|
||||
left_slave_arm_angle = left_slave_arm.position
|
||||
left_slave_arm_velocity = left_slave_arm.velocity
|
||||
left_slave_arm_force = left_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||
|
||||
right_slave_arm_angle = right_slave_arm.position
|
||||
right_slave_arm_velocity = right_slave_arm.velocity
|
||||
right_slave_arm_force = right_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||
|
||||
# 收集数据
|
||||
obs = collections.OrderedDict(
|
||||
{
|
||||
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||
"qvel": np.concatenate(
|
||||
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||
),
|
||||
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||
"images": {
|
||||
self.camera_names[0]: camera_head_np_img,
|
||||
self.camera_names[1]: camera_bottom_np_img,
|
||||
self.camera_names[2]: camera_left_np_img,
|
||||
self.camera_names[3]: camera_right_np_img,
|
||||
},
|
||||
}
|
||||
)
|
||||
self.ts = dm_env.TimeStep(
|
||||
step_type=(
|
||||
dm_env.StepType.FIRST if self.is_frist_step else dm_env.StepType.MID
|
||||
),
|
||||
reward=0.0,
|
||||
discount=1.0,
|
||||
observation=obs,
|
||||
)
|
||||
|
||||
def reset(self):
|
||||
|
||||
left_joint_state = JointState()
|
||||
left_joint_state.header.stamp = rospy.Time.now()
|
||||
left_joint_state.name = ["joint_j"]
|
||||
left_joint_state.position = self.init_left_arm_angle[0 : self.arm_axis + 1]
|
||||
|
||||
right_joint_state = JointState()
|
||||
right_joint_state.header.stamp = rospy.Time.now()
|
||||
right_joint_state.name = ["joint_j"]
|
||||
right_joint_state.position = self.init_right_arm_angle[0 : self.arm_axis + 1]
|
||||
|
||||
self.left_slave_arm_pub.publish(left_joint_state)
|
||||
self.right_slave_arm_pub.publish(right_joint_state)
|
||||
while self.ts is None:
|
||||
time.sleep(0.002)
|
||||
|
||||
return self.ts
|
||||
|
||||
|
||||
|
||||
def step(self, target_angle):
|
||||
self.is_frist_step = False
|
||||
left_joint_state = JointState()
|
||||
left_joint_state.header.stamp = rospy.Time.now()
|
||||
left_joint_state.name = ["joint_canfd"]
|
||||
left_joint_state.position = target_angle[0 : self.arm_axis + 1]
|
||||
# print("left_joint_state: ", left_joint_state)
|
||||
|
||||
right_joint_state = JointState()
|
||||
right_joint_state.header.stamp = rospy.Time.now()
|
||||
right_joint_state.name = ["joint_canfd"]
|
||||
right_joint_state.position = target_angle[self.arm_axis + 1 : (self.arm_axis + 1) * 2]
|
||||
# print("right_joint_state: ", right_joint_state)
|
||||
|
||||
self.left_slave_arm_pub.publish(left_joint_state)
|
||||
self.right_slave_arm_pub.publish(right_joint_state)
|
||||
# time.sleep(0.013)
|
||||
return self.ts
|
||||
|
||||
def run(self):
|
||||
rospy.loginfo("Starting ROS spin")
|
||||
data = np.concatenate([self.init_left_arm_angle, self.init_right_arm_angle])
|
||||
self.reset()
|
||||
# print("data: ", data)
|
||||
while not rospy.is_shutdown():
|
||||
self.step(data)
|
||||
rospy.sleep(0.010)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
synchronizer = DataSynchronizer("/home/rm/code/shadow_act/config/config.yaml")
|
||||
start_time = time.time()
|
||||
synchronizer.run()
|
||||
@@ -0,0 +1,280 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import time
|
||||
import h5py
|
||||
import yaml
|
||||
import rospy
|
||||
import dm_env
|
||||
import numpy as np
|
||||
import collections
|
||||
from datetime import datetime
|
||||
from std_msgs.msg import Int32MultiArray
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||
|
||||
|
||||
class DataCollector:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dir,
|
||||
dataset_name,
|
||||
max_timesteps,
|
||||
camera_names,
|
||||
state_dim,
|
||||
overwrite=False,
|
||||
):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.dataset_name = dataset_name
|
||||
self.max_timesteps = max_timesteps
|
||||
self.camera_names = camera_names
|
||||
self.state_dim = state_dim
|
||||
self.overwrite = overwrite
|
||||
self.init_dict()
|
||||
self.create_dataset_dir()
|
||||
|
||||
def init_dict(self):
|
||||
self.data_dict = {
|
||||
"/observations/qpos": [],
|
||||
"/observations/qvel": [],
|
||||
"/observations/effort": [],
|
||||
"/action": [],
|
||||
}
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"] = []
|
||||
|
||||
def create_dataset_dir(self):
|
||||
# 按照年月日创建目录
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||
if not os.path.exists(self.dataset_dir):
|
||||
os.makedirs(self.dataset_dir)
|
||||
|
||||
def create_file(self):
|
||||
# 检查数据集名称是否存在,如果存在则递增名称
|
||||
counter = 0
|
||||
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
|
||||
if not self.overwrite:
|
||||
while os.path.exists(dataset_path + ".hdf5"):
|
||||
dataset_path = os.path.join(
|
||||
self.dataset_dir, f"{self.dataset_name}_{counter}"
|
||||
)
|
||||
counter += 1
|
||||
self.dataset_path = dataset_path
|
||||
|
||||
def collect_data(self, ts, action):
|
||||
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
|
||||
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
|
||||
self.data_dict["/observations/effort"].append(ts.observation["effort"])
|
||||
self.data_dict["/action"].append(action)
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"].append(
|
||||
ts.observation["images"][cam_name]
|
||||
)
|
||||
|
||||
def save_data(self):
|
||||
self.create_file()
|
||||
t0 = time.time()
|
||||
# 保存数据
|
||||
with h5py.File(
|
||||
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
|
||||
) as root:
|
||||
root.attrs["sim"] = False
|
||||
obs = root.create_group("observations")
|
||||
image = obs.create_group("images")
|
||||
for cam_name in self.camera_names:
|
||||
_ = image.create_dataset(
|
||||
cam_name,
|
||||
(self.max_timesteps, 480, 640, 3),
|
||||
dtype="uint8",
|
||||
chunks=(1, 480, 640, 3),
|
||||
)
|
||||
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
|
||||
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
|
||||
|
||||
for name, array in self.data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f"Saving: {time.time() - t0:.1f} secs")
|
||||
return True
|
||||
|
||||
|
||||
class DataSynchronizer:
|
||||
def __init__(self, config_path="config"):
|
||||
rospy.init_node("synchronizer", anonymous=True)
|
||||
rospy.loginfo("ROS node initialized")
|
||||
with open(config_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
self.arm_axis = config["arm_axis"]
|
||||
# 创建订阅者
|
||||
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||
self.camera_bottom_sub = Subscriber(
|
||||
config["ros_topics"]["camera_bottom"], Image
|
||||
)
|
||||
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||
|
||||
self.left_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_master_arm"], JointState
|
||||
)
|
||||
self.left_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_slave_arm"], JointState
|
||||
)
|
||||
self.right_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_master_arm"], JointState
|
||||
)
|
||||
self.right_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_slave_arm"], JointState
|
||||
)
|
||||
self.left_aloha_state_pub = rospy.Subscriber(
|
||||
config["ros_topics"]["left_aloha_state"],
|
||||
Int32MultiArray,
|
||||
self.aloha_state_callback,
|
||||
)
|
||||
|
||||
rospy.loginfo("Subscribers created")
|
||||
self.camera_names = config["camera_names"]
|
||||
|
||||
# 创建同步器
|
||||
self.ats = ApproximateTimeSynchronizer(
|
||||
[
|
||||
self.camera_left_sub,
|
||||
self.camera_right_sub,
|
||||
self.camera_bottom_sub,
|
||||
self.camera_head_sub,
|
||||
self.left_master_arm_sub,
|
||||
self.left_slave_arm_sub,
|
||||
self.right_master_arm_sub,
|
||||
self.right_slave_arm_sub,
|
||||
],
|
||||
queue_size=1,
|
||||
slop=0.05,
|
||||
)
|
||||
self.ats.registerCallback(self.callback)
|
||||
rospy.loginfo("Time synchronizer created and callback registered")
|
||||
|
||||
self.data_collector = DataCollector(
|
||||
dataset_dir=config["dataset_dir"],
|
||||
dataset_name=config["dataset_name"],
|
||||
max_timesteps=config["max_timesteps"],
|
||||
camera_names=config["camera_names"],
|
||||
state_dim=config["state_dim"],
|
||||
overwrite=config["overwrite"],
|
||||
)
|
||||
self.timesteps_collected = 0
|
||||
self.begin_collect = False
|
||||
self.last_time = None
|
||||
|
||||
def callback(
|
||||
self,
|
||||
camera_left_img,
|
||||
camera_right_img,
|
||||
camera_bottom_img,
|
||||
camera_head_img,
|
||||
left_master_arm,
|
||||
left_slave_arm,
|
||||
right_master_arm,
|
||||
right_slave_arm,
|
||||
):
|
||||
if self.begin_collect:
|
||||
self.timesteps_collected += 1
|
||||
rospy.loginfo(
|
||||
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
|
||||
)
|
||||
else:
|
||||
self.timesteps_collected = 0
|
||||
return
|
||||
if self.timesteps_collected == 0:
|
||||
return
|
||||
current_time = time.time()
|
||||
if self.last_time is not None:
|
||||
frequency = 1.0 / (current_time - self.last_time)
|
||||
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
|
||||
self.last_time = current_time
|
||||
# 将ROS图像消息转换为NumPy数组
|
||||
camera_left_np_img = np.frombuffer(
|
||||
camera_left_img.data, dtype=np.uint8
|
||||
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||
camera_right_np_img = np.frombuffer(
|
||||
camera_right_img.data, dtype=np.uint8
|
||||
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||
camera_bottom_np_img = np.frombuffer(
|
||||
camera_bottom_img.data, dtype=np.uint8
|
||||
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||
camera_head_np_img = np.frombuffer(
|
||||
camera_head_img.data, dtype=np.uint8
|
||||
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||
|
||||
# 提取臂的角度,速度,力
|
||||
left_master_arm_angle = left_master_arm.position
|
||||
# left_master_arm_velocity = left_master_arm.velocity
|
||||
# left_master_arm_force = left_master_arm.effort
|
||||
|
||||
left_slave_arm_angle = left_slave_arm.position
|
||||
left_slave_arm_velocity = left_slave_arm.velocity
|
||||
left_slave_arm_force = left_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||
|
||||
right_master_arm_angle = right_master_arm.position
|
||||
# right_master_arm_velocity = right_master_arm.velocity
|
||||
# right_master_arm_force = right_master_arm.effort
|
||||
|
||||
right_slave_arm_angle = right_slave_arm.position
|
||||
right_slave_arm_velocity = right_slave_arm.velocity
|
||||
right_slave_arm_force = right_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||
|
||||
# 收集数据
|
||||
obs = collections.OrderedDict(
|
||||
{
|
||||
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||
"qvel": np.concatenate(
|
||||
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||
),
|
||||
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||
"images": {
|
||||
self.camera_names[0]: camera_head_np_img,
|
||||
self.camera_names[1]: camera_bottom_np_img,
|
||||
self.camera_names[2]: camera_left_np_img,
|
||||
self.camera_names[3]: camera_right_np_img,
|
||||
},
|
||||
}
|
||||
)
|
||||
print(self.camera_names[0])
|
||||
ts = dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
|
||||
)
|
||||
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
|
||||
self.data_collector.collect_data(ts, action)
|
||||
|
||||
# 检查是否收集了足够的数据
|
||||
if self.timesteps_collected >= self.data_collector.max_timesteps:
|
||||
self.data_collector.save_data()
|
||||
|
||||
rospy.loginfo("Data collection complete")
|
||||
|
||||
self.data_collector.init_dict()
|
||||
self.begin_collect = False
|
||||
self.timesteps_collected = 0
|
||||
|
||||
def aloha_state_callback(self, data):
|
||||
if not self.begin_collect:
|
||||
self.aloha_state = data.data
|
||||
print(self.aloha_state[0], self.aloha_state[1])
|
||||
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
|
||||
self.begin_collect = True
|
||||
|
||||
|
||||
def run(self):
|
||||
rospy.loginfo("Starting ROS spin")
|
||||
|
||||
rospy.spin()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
synchronizer = DataSynchronizer(
|
||||
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
|
||||
)
|
||||
synchronizer.run()
|
||||
@@ -0,0 +1,63 @@
|
||||
<launch>
|
||||
<!-- 左从臂节点 -->
|
||||
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||
<param name="aloha_state_topic" value="/left_slave_arm_aloha_state" type= "string"/>
|
||||
<param name="hz" value="120" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右从臂节点 -->
|
||||
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||
<param name="aloha_state_topic" value="/right_slave_arm_aloha_state" type= "string"/>
|
||||
<param name="hz" value="120" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左主臂节点 -->
|
||||
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右主臂节点 -->
|
||||
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/servo_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右臂相机节点 -->
|
||||
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="151222072576" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左臂相机节点 -->
|
||||
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="150622070125" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 顶部相机节点 -->
|
||||
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="241122071186" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 底部相机节点 -->
|
||||
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="152122078546" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
</launch>
|
||||
@@ -0,0 +1,49 @@
|
||||
<launch>
|
||||
<!-- 左从臂节点 -->
|
||||
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||
<param name="joint_actions_topic" value="/left_slave_arm_joint_actions" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右从臂节点 -->
|
||||
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_pub_sub.py" output="screen">
|
||||
<param name="arm_config" value="/home/rm/code/shadow_rm_aloha/config/rm_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||
<param name="joint_actions_topic" value="/right_slave_arm_joint_actions" type= "string"/>
|
||||
<param name="hz" value="90" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右臂相机节点 -->
|
||||
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="151222072576" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左臂相机节点 -->
|
||||
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="150622070125" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 顶部相机节点 -->
|
||||
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="241122071186" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 底部相机节点 -->
|
||||
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="152122078546" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
</launch>
|
||||
@@ -0,0 +1,61 @@
|
||||
<launch>
|
||||
<!-- 左从臂节点 -->
|
||||
<node name="slave_arm_publisher_left" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_slave_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右从臂节点 -->
|
||||
<node name="slave_arm_publisher_right" pkg="shadow_rm_aloha" type="slave_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/rm_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_slave_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左主臂节点 -->
|
||||
<node name="master_arm_publisher_left" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_left_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/left_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右主臂节点 -->
|
||||
<node name="master_arm_publisher_right" pkg="shadow_rm_aloha" type="master_arm_publisher.py" output="screen">
|
||||
<param name="arm_config" value="/home/wang/project/shadow_rm_aloha-main/config/servo_right_arm.yaml" type= "string"/>
|
||||
<param name="joint_states_topic" value="/right_master_arm_joint_states" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 右臂相机节点 -->
|
||||
<node name="camera_publisher_right" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="216322070299" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_right/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_right/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 左臂相机节点 -->
|
||||
<node name="camera_publisher_left" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="216322074992" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_left/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_left/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 顶部相机节点 -->
|
||||
<node name="camera_publisher_head" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="215322076086" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_head/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_head/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
|
||||
<!-- 底部相机节点 -->
|
||||
<node name="camera_publisher_bottom" pkg="shadow_rm_aloha" type="camera_publisher.py" output="screen">
|
||||
<param name="serial_number" value="215222074360" type= "string"/>
|
||||
<param name="rgb_topic" value="/camera_bottom/rgb/image_raw" type= "string"/>
|
||||
<param name="depth_topic" value="/camera_bottom/depth/image_raw" type= "string"/>
|
||||
<param name="hz" value="50" type= "int"/>
|
||||
</node>
|
||||
</launch>
|
||||
@@ -0,0 +1,284 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import h5py
|
||||
import yaml
|
||||
import json
|
||||
import rospy
|
||||
import dm_env
|
||||
import socket
|
||||
import numpy as np
|
||||
import collections
|
||||
from datetime import datetime
|
||||
from std_msgs.msg import Int32MultiArray
|
||||
from sensor_msgs.msg import Image, JointState
|
||||
from message_filters import Subscriber, ApproximateTimeSynchronizer
|
||||
|
||||
|
||||
class DataCollector:
|
||||
def __init__(
|
||||
self,
|
||||
dataset_dir,
|
||||
dataset_name,
|
||||
max_timesteps,
|
||||
camera_names,
|
||||
state_dim,
|
||||
overwrite=False,
|
||||
):
|
||||
self.dataset_dir = dataset_dir
|
||||
self.dataset_name = dataset_name
|
||||
self.max_timesteps = max_timesteps
|
||||
self.camera_names = camera_names
|
||||
self.state_dim = state_dim
|
||||
self.overwrite = overwrite
|
||||
self.init_dict()
|
||||
self.create_dataset_dir()
|
||||
|
||||
def init_dict(self):
|
||||
self.data_dict = {
|
||||
"/observations/qpos": [],
|
||||
"/observations/qvel": [],
|
||||
"/observations/effort": [],
|
||||
"/action": [],
|
||||
}
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"] = []
|
||||
|
||||
def create_dataset_dir(self):
|
||||
# 按照年月日创建目录
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
self.dataset_dir = os.path.join(self.dataset_dir, date_str)
|
||||
if not os.path.exists(self.dataset_dir):
|
||||
os.makedirs(self.dataset_dir)
|
||||
|
||||
def create_file(self):
|
||||
# 检查数据集名称是否存在,如果存在则递增名称
|
||||
counter = 0
|
||||
dataset_path = os.path.join(self.dataset_dir, f"{self.dataset_name}_{counter}")
|
||||
if not self.overwrite:
|
||||
while os.path.exists(dataset_path + ".hdf5"):
|
||||
dataset_path = os.path.join(
|
||||
self.dataset_dir, f"{self.dataset_name}_{counter}"
|
||||
)
|
||||
counter += 1
|
||||
self.dataset_path = dataset_path
|
||||
|
||||
def collect_data(self, ts, action):
|
||||
self.data_dict["/observations/qpos"].append(ts.observation["qpos"])
|
||||
self.data_dict["/observations/qvel"].append(ts.observation["qvel"])
|
||||
self.data_dict["/observations/effort"].append(ts.observation["effort"])
|
||||
self.data_dict["/action"].append(action)
|
||||
for cam_name in self.camera_names:
|
||||
self.data_dict[f"/observations/images/{cam_name}"].append(
|
||||
ts.observation["images"][cam_name]
|
||||
)
|
||||
|
||||
def save_data(self):
|
||||
self.create_file()
|
||||
t0 = time.time()
|
||||
# 保存数据
|
||||
with h5py.File(
|
||||
self.dataset_path + ".hdf5", mode="w", rdcc_nbytes=1024**2 * 2
|
||||
) as root:
|
||||
root.attrs["sim"] = False
|
||||
obs = root.create_group("observations")
|
||||
image = obs.create_group("images")
|
||||
for cam_name in self.camera_names:
|
||||
_ = image.create_dataset(
|
||||
cam_name,
|
||||
(self.max_timesteps, 480, 640, 3),
|
||||
dtype="uint8",
|
||||
chunks=(1, 480, 640, 3),
|
||||
)
|
||||
_ = obs.create_dataset("qpos", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("qvel", (self.max_timesteps, self.state_dim))
|
||||
_ = obs.create_dataset("effort", (self.max_timesteps, self.state_dim))
|
||||
_ = root.create_dataset("action", (self.max_timesteps, self.state_dim))
|
||||
|
||||
for name, array in self.data_dict.items():
|
||||
root[name][...] = array
|
||||
print(f"Saving: {time.time() - t0:.1f} secs")
|
||||
return True
|
||||
|
||||
|
||||
class DataSynchronizer:
|
||||
def __init__(self, config_path="config"):
|
||||
rospy.init_node("synchronizer", anonymous=True)
|
||||
rospy.loginfo("ROS node initialized")
|
||||
with open(config_path, "r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
self.arm_axis = config["arm_axis"]
|
||||
# 创建订阅者
|
||||
self.camera_left_sub = Subscriber(config["ros_topics"]["camera_left"], Image)
|
||||
self.camera_right_sub = Subscriber(config["ros_topics"]["camera_right"], Image)
|
||||
self.camera_bottom_sub = Subscriber(
|
||||
config["ros_topics"]["camera_bottom"], Image
|
||||
)
|
||||
self.camera_head_sub = Subscriber(config["ros_topics"]["camera_head"], Image)
|
||||
|
||||
self.left_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_master_arm"], JointState
|
||||
)
|
||||
self.left_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["left_slave_arm"], JointState
|
||||
)
|
||||
self.right_master_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_master_arm"], JointState
|
||||
)
|
||||
self.right_slave_arm_sub = Subscriber(
|
||||
config["ros_topics"]["right_slave_arm"], JointState
|
||||
)
|
||||
self.left_aloha_state_pub = rospy.Subscriber(
|
||||
config["ros_topics"]["left_aloha_state"],
|
||||
Int32MultiArray,
|
||||
self.aloha_state_callback,
|
||||
)
|
||||
|
||||
rospy.loginfo("Subscribers created")
|
||||
|
||||
# 创建同步器
|
||||
self.ats = ApproximateTimeSynchronizer(
|
||||
[
|
||||
self.camera_left_sub,
|
||||
self.camera_right_sub,
|
||||
self.camera_bottom_sub,
|
||||
self.camera_head_sub,
|
||||
self.left_master_arm_sub,
|
||||
self.left_slave_arm_sub,
|
||||
self.right_master_arm_sub,
|
||||
self.right_slave_arm_sub,
|
||||
],
|
||||
queue_size=1,
|
||||
slop=0.05,
|
||||
)
|
||||
self.ats.registerCallback(self.callback)
|
||||
rospy.loginfo("Time synchronizer created and callback registered")
|
||||
|
||||
self.data_collector = DataCollector(
|
||||
dataset_dir=config["dataset_dir"],
|
||||
dataset_name=config["dataset_name"],
|
||||
max_timesteps=config["max_timesteps"],
|
||||
camera_names=config["camera_names"],
|
||||
state_dim=config["state_dim"],
|
||||
overwrite=config["overwrite"],
|
||||
)
|
||||
self.timesteps_collected = 0
|
||||
self.begin_collect = False
|
||||
self.last_time = None
|
||||
|
||||
def callback(
|
||||
self,
|
||||
camera_left_img,
|
||||
camera_right_img,
|
||||
camera_bottom_img,
|
||||
camera_head_img,
|
||||
left_master_arm,
|
||||
left_slave_arm,
|
||||
right_master_arm,
|
||||
right_slave_arm,
|
||||
):
|
||||
if self.begin_collect:
|
||||
self.timesteps_collected += 1
|
||||
rospy.loginfo(
|
||||
f"Collecting data: {self.timesteps_collected}/{self.data_collector.max_timesteps}"
|
||||
)
|
||||
else:
|
||||
self.timesteps_collected = 0
|
||||
return
|
||||
if self.timesteps_collected == 0:
|
||||
return
|
||||
current_time = time.time()
|
||||
if self.last_time is not None:
|
||||
frequency = 1.0 / (current_time - self.last_time)
|
||||
rospy.loginfo(f"Callback frequency: {frequency:.2f} Hz")
|
||||
self.last_time = current_time
|
||||
# 将ROS图像消息转换为NumPy数组
|
||||
camera_left_np_img = np.frombuffer(
|
||||
camera_left_img.data, dtype=np.uint8
|
||||
).reshape(camera_left_img.height, camera_left_img.width, -1)
|
||||
camera_right_np_img = np.frombuffer(
|
||||
camera_right_img.data, dtype=np.uint8
|
||||
).reshape(camera_right_img.height, camera_right_img.width, -1)
|
||||
camera_bottom_np_img = np.frombuffer(
|
||||
camera_bottom_img.data, dtype=np.uint8
|
||||
).reshape(camera_bottom_img.height, camera_bottom_img.width, -1)
|
||||
camera_head_np_img = np.frombuffer(
|
||||
camera_head_img.data, dtype=np.uint8
|
||||
).reshape(camera_head_img.height, camera_head_img.width, -1)
|
||||
|
||||
# 提取臂的角度,速度,力
|
||||
left_master_arm_angle = left_master_arm.position
|
||||
# left_master_arm_velocity = left_master_arm.velocity
|
||||
# left_master_arm_force = left_master_arm.effort
|
||||
|
||||
left_slave_arm_angle = left_slave_arm.position
|
||||
left_slave_arm_velocity = left_slave_arm.velocity
|
||||
left_slave_arm_force = left_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同, 非因时夹爪请注释
|
||||
# left_slave_arm_angle[self.arm_axis] = left_master_arm_angle[self.arm_axis]
|
||||
|
||||
right_master_arm_angle = right_master_arm.position
|
||||
# right_master_arm_velocity = right_master_arm.velocity
|
||||
# right_master_arm_force = right_master_arm.effort
|
||||
|
||||
right_slave_arm_angle = right_slave_arm.position
|
||||
right_slave_arm_velocity = right_slave_arm.velocity
|
||||
right_slave_arm_force = right_slave_arm.effort
|
||||
# 因时夹爪的角度与主臂的角度相同,, 非因时夹爪请注释
|
||||
# right_slave_arm_angle[self.arm_axis] = right_master_arm_angle[self.arm_axis]
|
||||
|
||||
# 收集数据
|
||||
obs = collections.OrderedDict(
|
||||
{
|
||||
"qpos": np.concatenate([left_slave_arm_angle, right_slave_arm_angle]),
|
||||
"qvel": np.concatenate(
|
||||
[left_slave_arm_velocity, right_slave_arm_velocity]
|
||||
),
|
||||
"effort": np.concatenate([left_slave_arm_force, right_slave_arm_force]),
|
||||
"images": {
|
||||
"cam_front": camera_head_np_img,
|
||||
"cam_low": camera_bottom_np_img,
|
||||
"cam_left": camera_left_np_img,
|
||||
"cam_right": camera_right_np_img,
|
||||
},
|
||||
}
|
||||
)
|
||||
ts = dm_env.TimeStep(
|
||||
step_type=dm_env.StepType.MID, reward=0, discount=None, observation=obs
|
||||
)
|
||||
action = np.concatenate([left_master_arm_angle, right_master_arm_angle])
|
||||
self.data_collector.collect_data(ts, action)
|
||||
|
||||
# 检查是否收集了足够的数据
|
||||
if self.timesteps_collected >= self.data_collector.max_timesteps:
|
||||
self.data_collector.save_data()
|
||||
|
||||
rospy.loginfo("Data collection complete")
|
||||
|
||||
self.data_collector.init_dict()
|
||||
self.begin_collect = False
|
||||
self.timesteps_collected = 0
|
||||
|
||||
def aloha_state_callback(self, data):
|
||||
if not self.begin_collect:
|
||||
self.aloha_state = data.data
|
||||
print(self.aloha_state[0], self.aloha_state[1])
|
||||
if self.aloha_state[0] == 1 and self.aloha_state[1] == 1:
|
||||
self.begin_collect = True
|
||||
|
||||
|
||||
def run(self):
|
||||
rospy.loginfo("Starting ROS spin")
|
||||
|
||||
rospy.spin()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
synchronizer = DataSynchronizer(
|
||||
"/home/rm/code/shadow_rm_aloha/config/data_synchronizer.yaml"
|
||||
)
|
||||
synchronizer.run()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
<?xml version="1.0"?>
|
||||
<package format="2">
|
||||
<name>shadow_rm_aloha</name>
|
||||
<version>0.0.1</version>
|
||||
<description>The shadow_rm_aloha package</description>
|
||||
|
||||
<maintainer email="your_email@example.com">Your Name</maintainer>
|
||||
|
||||
<license>TODO</license>
|
||||
|
||||
<buildtool_depend>catkin</buildtool_depend>
|
||||
|
||||
<build_depend>rospy</build_depend>
|
||||
<build_depend>sensor_msgs</build_depend>
|
||||
<build_depend>std_msgs</build_depend>
|
||||
<build_depend>cv_bridge</build_depend>
|
||||
<build_depend>image_transport</build_depend>
|
||||
<build_depend>message_generation</build_depend>
|
||||
<build_depend>message_runtime</build_depend>
|
||||
|
||||
<exec_depend>rospy</exec_depend>
|
||||
<exec_depend>sensor_msgs</exec_depend>
|
||||
<exec_depend>std_msgs</exec_depend>
|
||||
<exec_depend>cv_bridge</exec_depend>
|
||||
<exec_depend>image_transport</exec_depend>
|
||||
<exec_depend>message_runtime</exec_depend>
|
||||
|
||||
|
||||
<export>
|
||||
</export>
|
||||
</package>
|
||||
@@ -0,0 +1,5 @@
|
||||
# GetArmStatus.srv
|
||||
|
||||
---
|
||||
sensor_msgs/JointState joint_status
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# GetImage.srv
|
||||
---
|
||||
bool success
|
||||
sensor_msgs/Image image
|
||||
@@ -0,0 +1,4 @@
|
||||
# MoveArm.srv
|
||||
float32[] joint_angle
|
||||
---
|
||||
bool success
|
||||
@@ -0,0 +1 @@
|
||||
__version__ = '0.1.0'
|
||||
49
realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py
Normal file
49
realman_src/realman_aloha/shadow_rm_aloha/test/mu_test.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import multiprocessing as mp
|
||||
import time
|
||||
|
||||
def collect_data(arm_id, cam_id, data_queue, lock):
|
||||
while True:
|
||||
# 模拟数据采集
|
||||
arm_data = f"Arm {arm_id} data"
|
||||
cam_data = f"Cam {cam_id} data"
|
||||
|
||||
# 获取当前时间戳
|
||||
timestamp = time.time()
|
||||
|
||||
# 将数据放入队列
|
||||
with lock:
|
||||
data_queue.put((timestamp, arm_data, cam_data))
|
||||
|
||||
# 模拟高帧率
|
||||
time.sleep(0.01)
|
||||
|
||||
def main():
|
||||
num_arms = 4
|
||||
num_cams = 4
|
||||
|
||||
# 创建队列和锁
|
||||
data_queue = mp.Queue()
|
||||
lock = mp.Lock()
|
||||
|
||||
# 创建进程
|
||||
processes = []
|
||||
for i in range(num_arms):
|
||||
p = mp.Process(target=collect_data, args=(i, i, data_queue, lock))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
# 主进程处理数据
|
||||
try:
|
||||
while True:
|
||||
if not data_queue.empty():
|
||||
with lock:
|
||||
timestamp, arm_data, cam_data = data_queue.get()
|
||||
print(f"Timestamp: {timestamp}, {arm_data}, {cam_data}")
|
||||
except KeyboardInterrupt:
|
||||
for p in processes:
|
||||
p.terminate()
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,38 @@
|
||||
import os
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from shadow_rm_aloha.data_sub_process.aloha_data_synchronizer import DataCollector
|
||||
|
||||
def test_create_dataset_dir():
|
||||
# 设置测试参数
|
||||
dataset_dir = './test_data/dataset'
|
||||
dataset_name = 'test_episode'
|
||||
max_timesteps = 100
|
||||
camera_names = ['cam1', 'cam2']
|
||||
overwrite = False
|
||||
|
||||
# 清理旧的测试数据
|
||||
if os.path.exists(dataset_dir):
|
||||
shutil.rmtree(dataset_dir)
|
||||
|
||||
# 创建 DataCollector 实例并调用 create_dataset_dir
|
||||
collector = DataCollector(dataset_dir, dataset_name, max_timesteps, camera_names, overwrite)
|
||||
|
||||
# 检查目录是否按预期创建
|
||||
date_str = datetime.now().strftime("%Y%m%d")
|
||||
expected_dir = os.path.join(dataset_dir, date_str)
|
||||
assert os.path.exists(expected_dir), f"Expected directory {expected_dir} does not exist."
|
||||
|
||||
# 检查文件名是否按预期递增
|
||||
expected_file = os.path.join(expected_dir, dataset_name + '.hdf5')
|
||||
assert collector.dataset_path == expected_file, f"Expected file path {expected_file}, but got {collector.dataset_path}"
|
||||
|
||||
# 再次调用 create_dataset_dir,检查文件名是否递增
|
||||
# collector.create_dataset_dir()
|
||||
expected_file_incremented = os.path.join(expected_dir, dataset_name + '_1.hdf5')
|
||||
assert collector.dataset_path == expected_file_incremented, f"Expected file path {expected_file_incremented}, but got {collector.dataset_path}"
|
||||
|
||||
print("All tests passed.")
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_create_dataset_dir()
|
||||
105
realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py
Normal file
105
realman_src/realman_aloha/shadow_rm_aloha/test/udp_test.py
Normal file
@@ -0,0 +1,105 @@
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
import random
|
||||
import socket
|
||||
import json
|
||||
import logging
|
||||
|
||||
# 设置日志级别
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
|
||||
|
||||
class test_udp():
|
||||
def __init__(self):
|
||||
arm_ip = '192.168.1.19'
|
||||
arm_port = 8080
|
||||
self.arm =socket.socket()
|
||||
self.arm.connect((arm_ip, arm_port))
|
||||
set_udp = {"command":"set_realtime_push","cycle":1,"enable":True,"port":8090,"ip":"192.168.1.101","custom":{"aloha_state":True,"joint_speed":True,"arm_current_status":True,"hand":False, "expand_state":True}}
|
||||
self.arm.send(json.dumps(set_udp).encode('utf-8'))
|
||||
state = self.arm.recv(1024)
|
||||
|
||||
logging.info(f"Send data to {arm_ip}:{arm_port}: {state}")
|
||||
|
||||
self.udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
|
||||
# 设置套接字选项,允许端口复用
|
||||
self.udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
local_ip = "192.168.1.101"
|
||||
local_port = 8090
|
||||
self.udp_socket.bind((local_ip, local_port))
|
||||
self.BUFFER_SIZE = 1024
|
||||
|
||||
|
||||
|
||||
def set_udp(self):
|
||||
|
||||
while True:
|
||||
start_time = time.time()
|
||||
data, addr = self.udp_socket.recvfrom(self.BUFFER_SIZE)
|
||||
# 将接收到的UDP数据解码并解析为JSON
|
||||
data = json.loads(data.decode('utf-8'))
|
||||
end_time = time.time()
|
||||
print(f"Received data {data}")
|
||||
|
||||
udp_socket.close()
|
||||
|
||||
|
||||
def collect_arm_data(arm_id, queue, event):
|
||||
while True:
|
||||
data = f"Arm {arm_id} data {random.random()}"
|
||||
queue.put((arm_id, data))
|
||||
event.set()
|
||||
time.sleep(1)
|
||||
|
||||
def collect_camera_data(camera_id, queue, event):
|
||||
while True:
|
||||
data = f"Camera {camera_id} data {random.random()}"
|
||||
queue.put((camera_id, data))
|
||||
event.set()
|
||||
time.sleep(1)
|
||||
|
||||
def main():
|
||||
arm_queues = [multiprocessing.Queue() for _ in range(4)]
|
||||
camera_queues = [multiprocessing.Queue() for _ in range(4)]
|
||||
arm_events = [multiprocessing.Event() for _ in range(4)]
|
||||
camera_events = [multiprocessing.Event() for _ in range(4)]
|
||||
|
||||
arm_processes = [multiprocessing.Process(target=collect_arm_data, args=(i, arm_queues[i], arm_events[i])) for i in range(4)]
|
||||
camera_processes = [multiprocessing.Process(target=collect_camera_data, args=(i, camera_queues[i], camera_events[i])) for i in range(4)]
|
||||
|
||||
for p in arm_processes + camera_processes:
|
||||
p.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
for event in arm_events + camera_events:
|
||||
event.wait()
|
||||
|
||||
for i in range(4):
|
||||
if not arm_queues[i].empty():
|
||||
arm_id, arm_data = arm_queues[i].get()
|
||||
print(f"Received from Arm {arm_id}: {arm_data}")
|
||||
arm_events[i].clear()
|
||||
|
||||
if not camera_queues[i].empty():
|
||||
camera_id, camera_data = camera_queues[i].get()
|
||||
print(f"Received from Camera {camera_id}: {camera_data}")
|
||||
camera_events[i].clear()
|
||||
|
||||
time.sleep(0.1)
|
||||
except KeyboardInterrupt:
|
||||
for p in arm_processes + camera_processes:
|
||||
p.terminate()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# test_udp = test_udp()
|
||||
# test_udp.set_udp()
|
||||
Reference in New Issue
Block a user