Compare commits

...

8 Commits

Author SHA1 Message Date
25fb9c0d33 add convert aloha 2 lerobot 2025-04-20 21:50:20 +08:00
722de584d2 当前代码有问题 2025-04-13 21:43:26 +08:00
a4fe5ee09a 使用原生的数据搜集代码 2025-04-13 21:41:45 +08:00
3df284ddd1 Improve the gui_app for data collection 2025-04-08 22:52:27 +08:00
88885a6a25 add gui & modify camera view 2025-04-08 12:11:54 +08:00
aa3920dd28 fix 2025-04-07 22:03:32 +08:00
e034881507 restructure code 2025-04-07 20:32:39 +08:00
d843a990a3 modify code structure 2025-04-07 19:45:34 +08:00
51 changed files with 3389 additions and 1816 deletions

5
.gitignore vendored
View File

@@ -1,2 +1,5 @@
cobot_magic/
librealsense/
librealsense/
data*/
outputs/
lerobot_datasets/

View File

@@ -0,0 +1 @@
# Ask the dynamic loader to preload the system libtiff (version 5) for
# programs started from this environment.
# NOTE(review): presumably a workaround for a libtiff version clash between
# system and bundled libraries - confirm before removing.
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5

View File

@@ -1,3 +0,0 @@
python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data
python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1

Binary file not shown.

Binary file not shown.

Binary file not shown.

23
collect_data/aloha.yaml Normal file
View File

@@ -0,0 +1,23 @@
# Data-collection settings read by collect_data.py.

# Keys under which the colour images are stored. AlohaRobotRos.process uses
# index 0 for the front image, 1 for the left and 2 for the right.
camera_names:
- cam_high
- cam_left_wrist
- cam_right_wrist
# Episodes are written to <dataset_dir>/<task_name>/episode_<index>.
dataset_dir: /home/ubuntu/LYT/lerobot_aloha/datasets/3camera
# Index given to the first recorded episode.
episode_idx: 0
# Rate of the recording loop (passed to rospy.Rate).
frame_rate: 30
# Image topics; the depth topics are only subscribed when use_depth_image is true.
img_front_depth_topic: /camera_f/depth/image_raw
img_front_topic: /camera_f/color/image_raw
img_left_depth_topic: /camera_l/depth/image_raw
img_left_topic: /camera_l/color/image_raw
img_right_depth_topic: /camera_r/depth/image_raw
img_right_topic: /camera_r/color/image_raw
# Master arm joint states are recorded as the actions.
master_arm_left_topic: /master/joint_left
master_arm_right_topic: /master/joint_right
# Number of actions per episode; shorter recordings are not saved.
max_timesteps: 500
# Number of episodes to collect; 1 selects the single-episode mode.
num_episodes: 50
# Puppet arm joint states are recorded as qpos / qvel / effort observations.
puppet_arm_left_topic: /puppet/joint_left
puppet_arm_right_topic: /puppet/joint_right
# Odometry topic, relevant only when use_robot_base is true.
robot_base_topic: /odom
task_name: aloha_mobile_dummy
use_depth_image: false
use_robot_base: false

View File

@@ -0,0 +1,305 @@
import cv2
import numpy as np
import dm_env
import collections
from collections import deque
import rospy
from sensor_msgs.msg import JointState
from sensor_msgs.msg import Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
from utils import display_camera_grid, save_data
class AlohaRobotRos:
def __init__(self, args):
    """Store the configuration, create the message buffers and start ROS.

    Args:
        args: configuration object kept on ``self.args``; it is read later
            by the other methods of this class.
    """
    # Declare every attribute up front; init() replaces these placeholders
    # with the real bridge and buffers.
    placeholders = (
        "robot_base_deque",
        "puppet_arm_right_deque",
        "puppet_arm_left_deque",
        "master_arm_right_deque",
        "master_arm_left_deque",
        "img_front_deque",
        "img_right_deque",
        "img_left_deque",
        "img_front_depth_deque",
        "img_right_depth_deque",
        "img_left_depth_deque",
        "bridge",
    )
    for attribute in placeholders:
        setattr(self, attribute, None)
    self.args = args
    self.init()
    self.init_ros()
def init(self):
    """Create the CvBridge and one empty deque per message stream."""
    self.bridge = CvBridge()
    buffer_names = (
        "img_left_deque",
        "img_right_deque",
        "img_front_deque",
        "img_left_depth_deque",
        "img_right_depth_deque",
        "img_front_depth_deque",
        "master_arm_left_deque",
        "master_arm_right_deque",
        "puppet_arm_left_deque",
        "puppet_arm_right_deque",
        "robot_base_deque",
    )
    # Every stream gets its own, independent buffer.
    for buffer_name in buffer_names:
        setattr(self, buffer_name, deque())
def get_frame(self):
    """Pop one time-aligned sample from the message buffers.

    The sample time is the smallest of the newest image stamps (colour
    images, plus depth images when ``args.use_depth_image`` is set).  A
    sample is only produced when every required buffer already holds a
    message stamped at or after that time.  For each buffer, messages older
    than the sample time are discarded and the next one is consumed.

    Returns:
        ``False`` when no aligned sample is available yet, otherwise the
        tuple ``(img_front, img_left, img_right, img_front_depth,
        img_left_depth, img_right_depth, puppet_arm_left, puppet_arm_right,
        master_arm_left, master_arm_right, robot_base)``.  The depth entries
        are ``None`` unless ``args.use_depth_image`` is set and
        ``robot_base`` is ``None`` unless ``args.use_robot_base`` is set.
    """
    use_depth = self.args.use_depth_image
    use_base = self.args.use_robot_base

    color_queues = (self.img_left_deque, self.img_right_deque, self.img_front_deque)
    depth_queues = (self.img_left_depth_deque, self.img_right_depth_deque,
                    self.img_front_depth_deque)

    # Diagnostic output: current backlog of every image buffer.
    print(*(len(queue) for queue in color_queues + depth_queues))

    image_queues = (color_queues + depth_queues) if use_depth else color_queues
    if any(len(queue) == 0 for queue in image_queues):
        return False

    def newest_stamp(queue):
        """Return the time stamp (seconds) of the most recent message."""
        return queue[-1].header.stamp.to_sec()

    frame_time = min(newest_stamp(queue) for queue in image_queues)

    # Every stream that contributes to the sample must have caught up with
    # frame_time, otherwise nothing is consumed.
    required = list(image_queues) + [
        self.master_arm_left_deque,
        self.master_arm_right_deque,
        self.puppet_arm_left_deque,
        self.puppet_arm_right_deque,
    ]
    if use_base:
        required.append(self.robot_base_deque)
    if any(len(queue) == 0 or newest_stamp(queue) < frame_time for queue in required):
        return False

    def take(queue):
        """Drop messages older than frame_time and pop the next one."""
        while queue[0].header.stamp.to_sec() < frame_time:
            queue.popleft()
        return queue.popleft()

    def take_image(queue):
        """Pop the aligned image message and convert it with the bridge."""
        return self.bridge.imgmsg_to_cv2(take(queue), 'passthrough')

    def take_depth(queue):
        """Pop the aligned depth image and pad it with 40 zero rows top and bottom."""
        # NOTE(review): the padding presumably brings the depth image to the
        # colour image height - confirm against the camera resolution.
        return cv2.copyMakeBorder(take_image(queue), 40, 40, 0, 0,
                                  cv2.BORDER_CONSTANT, value=0)

    img_left = take_image(self.img_left_deque)
    img_right = take_image(self.img_right_deque)
    img_front = take_image(self.img_front_deque)

    master_arm_left = take(self.master_arm_left_deque)
    master_arm_right = take(self.master_arm_right_deque)
    puppet_arm_left = take(self.puppet_arm_left_deque)
    puppet_arm_right = take(self.puppet_arm_right_deque)

    img_left_depth = img_right_depth = img_front_depth = None
    if use_depth:
        img_left_depth = take_depth(self.img_left_depth_deque)
        img_right_depth = take_depth(self.img_right_depth_deque)
        img_front_depth = take_depth(self.img_front_depth_deque)

    robot_base = take(self.robot_base_deque) if use_base else None

    return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
            puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base)
def _enqueue(self, queue, msg, max_len=2000):
    """Append ``msg`` to ``queue``, dropping the oldest entry first when full.

    Args:
        queue: deque that buffers the messages of one stream.
        msg: message to store.
        max_len: number of buffered messages at which the oldest one is
            discarded before appending.
    """
    if len(queue) >= max_len:
        queue.popleft()
    queue.append(msg)
def img_left_callback(self, msg):
    """Buffer a message in ``img_left_deque``."""
    self._enqueue(self.img_left_deque, msg)
def img_right_callback(self, msg):
    """Buffer a message in ``img_right_deque``."""
    self._enqueue(self.img_right_deque, msg)
def img_front_callback(self, msg):
    """Buffer a message in ``img_front_deque``."""
    self._enqueue(self.img_front_deque, msg)
def img_left_depth_callback(self, msg):
    """Buffer a message in ``img_left_depth_deque``."""
    self._enqueue(self.img_left_depth_deque, msg)
def img_right_depth_callback(self, msg):
    """Buffer a message in ``img_right_depth_deque``."""
    self._enqueue(self.img_right_depth_deque, msg)
def img_front_depth_callback(self, msg):
    """Buffer a message in ``img_front_depth_deque``."""
    self._enqueue(self.img_front_depth_deque, msg)
def master_arm_left_callback(self, msg):
    """Buffer a message in ``master_arm_left_deque``."""
    self._enqueue(self.master_arm_left_deque, msg)
def master_arm_right_callback(self, msg):
    """Buffer a message in ``master_arm_right_deque``."""
    self._enqueue(self.master_arm_right_deque, msg)
def puppet_arm_left_callback(self, msg):
    """Buffer a message in ``puppet_arm_left_deque``."""
    self._enqueue(self.puppet_arm_left_deque, msg)
def puppet_arm_right_callback(self, msg):
    """Buffer a message in ``puppet_arm_right_deque``."""
    self._enqueue(self.puppet_arm_right_deque, msg)
def robot_base_callback(self, msg):
    """Buffer a message in ``robot_base_deque``."""
    self._enqueue(self.robot_base_deque, msg)
def init_ros(self):
    """Start the ROS node and subscribe to every configured topic.

    Colour image and arm joint topics are always subscribed.  Depth image
    topics are added when ``args.use_depth_image`` is set, and the odometry
    topic is added when ``args.use_robot_base`` is set.
    """
    args = self.args
    rospy.init_node('record_episodes', anonymous=True)

    subscriptions = [
        (args.img_left_topic, Image, self.img_left_callback),
        (args.img_right_topic, Image, self.img_right_callback),
        (args.img_front_topic, Image, self.img_front_callback),
    ]
    if args.use_depth_image:
        subscriptions += [
            (args.img_left_depth_topic, Image, self.img_left_depth_callback),
            (args.img_right_depth_topic, Image, self.img_right_depth_callback),
            (args.img_front_depth_topic, Image, self.img_front_depth_callback),
        ]
    subscriptions += [
        (args.master_arm_left_topic, JointState, self.master_arm_left_callback),
        (args.master_arm_right_topic, JointState, self.master_arm_right_callback),
        (args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback),
        (args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback),
    ]
    if args.use_robot_base:
        # Without this subscription robot_base_deque never receives a
        # message, so get_frame() would keep returning False whenever
        # use_robot_base is enabled.
        subscriptions.append(
            (args.robot_base_topic, Odometry, self.robot_base_callback))

    for topic, message_type, callback in subscriptions:
        rospy.Subscriber(topic, message_type, callback, queue_size=1000, tcp_nodelay=True)
def process(self):
    """Record one episode of synchronised observations and actions.

    Frames are pulled with ``get_frame()`` at ``args.frame_rate`` until
    ``args.max_timesteps + 1`` frames were collected or ROS shuts down.  The
    first frame only yields a ``StepType.FIRST`` time step; every later
    frame yields a ``StepType.MID`` time step plus one action built from the
    master arm positions.

    Returns:
        ``(timesteps, actions)``: the list of ``dm_env.TimeStep`` objects
        and the list of action arrays.  A complete episode therefore holds
        one more time step than actions.

    Raises:
        SystemExit: with code -1 when ROS reports shutdown right after a
            frame was recorded.
    """
    args = self.args
    timesteps = []
    actions = []
    count = 0
    rate = rospy.Rate(args.frame_rate)
    print_flag = True

    while (count < args.max_timesteps + 1) and not rospy.is_shutdown():
        # Collect one synchronised frame.
        result = self.get_frame()
        if not result:
            # Report a synchronisation failure only once per streak.
            if print_flag:
                print("syn fail\n")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        count += 1
        (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
         puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base) = result

        # Image information, keyed by the configured camera names.
        names = args.camera_names
        image_dict = {names[0]: img_front, names[1]: img_left, names[2]: img_right}
        display_camera_grid(image_dict)

        # Puppet arm state forms the observation.
        obs = collections.OrderedDict()
        obs['images'] = image_dict
        if args.use_depth_image:
            obs['images_depth'] = {
                names[0]: img_front_depth,
                names[1]: img_left_depth,
                names[2]: img_right_depth,
            }
        for obs_key, field in (('qpos', 'position'), ('qvel', 'velocity'), ('effort', 'effort')):
            obs[obs_key] = np.concatenate(
                (np.array(getattr(puppet_arm_left, field)),
                 np.array(getattr(puppet_arm_right, field))),
                axis=0)
        if args.use_robot_base:
            obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
        else:
            obs['base_vel'] = [0.0, 0.0]

        is_first = count == 1
        timesteps.append(dm_env.TimeStep(
            step_type=dm_env.StepType.FIRST if is_first else dm_env.StepType.MID,
            reward=None,
            discount=None,
            observation=obs))
        if is_first:
            # The first frame stores the observation only, no action.
            continue

        # Master arm positions are stored as the action.
        actions.append(np.concatenate(
            (np.array(master_arm_left.position), np.array(master_arm_right.position)), axis=0))

        print(f"\n{args.episode_idx} | Frame data: {count}\r", end="")
        if rospy.is_shutdown():
            # Raise directly instead of calling the site-provided exit()
            # helper, which is meant for interactive use.
            raise SystemExit(-1)
        rate.sleep()

    print("len(timesteps): ", len(timesteps))
    print("len(actions) : ", len(actions))
    return timesteps, actions

114
collect_data/collect_data.py Executable file
View File

@@ -0,0 +1,114 @@
import os
import time
from aloha_mobile import AlohaRobotRos
from utils import save_data, init_keyboard_listener, load_config, log_say
def main(config_path):
    """Collect one or more episodes as described by a config file.

    With ``num_episodes == 1`` a single episode is recorded and saved
    immediately.  Otherwise recording is driven by the keyboard events
    returned by ``init_keyboard_listener()``: ``record_start`` begins an
    episode, then ``save_data`` / ``discard_data`` decide what happens to it
    and ``exit_early`` stops the session.

    Args:
        config_path: path handed to ``load_config``.

    Returns:
        ``0`` when the session finished or was stopped by the user, ``-1``
        when the single-episode recording was shorter than
        ``max_timesteps`` and therefore not saved.
    """
    args = load_config(config_path)
    ros_operator = AlohaRobotRos(args)
    dataset_dir = os.path.join(args.dataset_dir, args.task_name)
    # Ensure dataset directory exists
    os.makedirs(dataset_dir, exist_ok=True)

    # Single episode collection mode
    if args.num_episodes == 1:
        print(f"Recording single episode {args.episode_idx}...")
        timesteps, actions = ros_operator.process()
        if len(actions) < args.max_timesteps:
            print(f"\033[31m\nSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m\n")
            return -1
        dataset_path = os.path.join(dataset_dir, f"episode_{args.episode_idx}")
        save_data(args, timesteps, actions, dataset_path)
        print(f"\033[32mEpisode {args.episode_idx} saved successfully at {dataset_path}\033[0m")
        return 0

    def announce_exit():
        """Tell the user that the session was stopped on request."""
        print("\033[33mOperation terminated by user\033[0m")
        log_say("操作被你停止了,如果这是个误操作,请重新开始。", play_sounds=True)

    # Multi-episode collection mode
    print("\n".join((
        "",
        "\033[1;36mKeyboard Controls:\033[0m",
        "\033[1mLeft Arrow\033[0m: Start Recording",
        "\033[1mRight Arrow\033[0m: Save Current Data",
        "\033[1mDown Arrow\033[0m: Discard Current Data",
        "\033[1mUp Arrow\033[0m: Replay Data (if implemented)",
        "\033[1mESC\033[0m: Exit Program",
        "",
    )))
    log_say("欢迎您为 具身智能科学家项目采集数据,您辛苦了。我已经将一切准备就绪,请您按方向左键开始录制数据。", play_sounds=True)

    listener, events = init_keyboard_listener()
    episode_idx = args.episode_idx
    collected_episodes = 0
    try:
        while collected_episodes < args.num_episodes:
            if events["exit_early"]:
                announce_exit()
                break
            if events["record_start"]:
                # Reset event states for new recording
                events["record_start"] = False
                events["save_data"] = False
                events["discard_data"] = False
                log_say(f"开始录制第{episode_idx}条轨迹,请开始操作机械臂。", play_sounds=True)
                print(f"\n\033[1;32mRecording episode {episode_idx}...\033[0m")
                timesteps, actions = ros_operator.process()
                print(f"\033[1;33mRecorded {len(actions)} timesteps. (→ to save, ↓ to discard)\033[0m")
                log_say(f"{episode_idx}条轨迹的录制已经到达最大时间步并录制结束,请选择是否保留该条轨迹。按方向右键保留,方向下键丢弃。", play_sounds=True)

                # Wait for user decision to save or discard
                while True:
                    if events["save_data"]:
                        events["save_data"] = False
                        if len(actions) < args.max_timesteps:
                            # Too short: the episode is dropped and has to be recorded again.
                            print(f"\033[31mSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m")
                            log_say("由于当前轨迹的实际时间步数小于最大时间步数,因此无法保存该条轨迹。该条轨迹将被丢弃,请重新录制。", play_sounds=True)
                        else:
                            dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
                            log_say(f"你选择了保留该轨迹作为第{episode_idx}条轨迹数据。接下来请你按方向左键开始录制下一条轨迹。", play_sounds=True)
                            save_data(args, timesteps, actions, dataset_path)
                            print(f"\033[32mEpisode {episode_idx} saved successfully at {dataset_path}\033[0m")
                            episode_idx += 1
                            collected_episodes += 1
                            print(f"\033[1mProgress: {collected_episodes}/{args.num_episodes} episodes collected. (← to start new episode)\033[0m")
                        break
                    if events["discard_data"]:
                        events["discard_data"] = False
                        log_say(f"你选择了丢弃该轨迹作为第{episode_idx}条轨迹数据。接下来请你按方向左键开始录制下一条轨迹。", play_sounds=True)
                        print("\033[33mData discarded. Press ← to start a new recording.\033[0m")
                        break
                    if events["exit_early"]:
                        announce_exit()
                        return 0
                    time.sleep(0.1)  # Reduce CPU usage
            time.sleep(0.1)  # Reduce CPU usage

        if collected_episodes == args.num_episodes:
            log_say("恭喜你,本次数据已经全部录制完成。您辛苦了~", play_sounds=True)
            print(f"\n\033[1;32mData collection complete! All {args.num_episodes} episodes collected.\033[0m")
    finally:
        # Ensure listener is cleaned up
        if listener:
            listener.stop()
            print("Keyboard listener stopped")
    # Report success explicitly so callers comparing the result with 0 do
    # not mistake a finished session for a failure.
    return 0
if __name__ == '__main__':
    import sys

    # The config path may be given as the first command-line argument; the
    # previous hard-coded location remains the default.
    config_path = (sys.argv[1] if len(sys.argv) > 1
                   else "/home/ubuntu/LYT/lerobot_aloha/collect_data/aloha.yaml")
    try:
        exit_code = main(config_path)
    except KeyboardInterrupt:
        print("\n\033[33mProgram interrupted by user\033[0m")
        sys.exit(0)
    except Exception as e:
        print(f"\n\033[31mError: {e}\033[0m")
        # Signal the failure to the calling shell instead of ending with
        # status 0 after an error.
        sys.exit(1)
    sys.exit(exit_code if exit_code is not None else 0)

View File

@@ -0,0 +1,403 @@
import sys
import os
import yaml
from collect_data import main
from types import SimpleNamespace
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
QLabel, QLineEdit, QPushButton, QSpinBox, QCheckBox,
QGroupBox, QTabWidget, QMessageBox, QFileDialog)
from PyQt5.QtCore import Qt
class AlohaDataCollectionGUI(QMainWindow):
def __init__(self):
    """Set up the main window, build the UI, wire signals and load defaults."""
    super().__init__()
    self.setWindowTitle("MindRobot-V1 Data Collection")
    self.setGeometry(100, 100, 800, 600)

    # A single central widget hosts the vertical main layout.
    central = QWidget()
    self.central_widget = central
    self.setCentralWidget(central)
    self.main_layout = QVBoxLayout(central)

    self.config_path = os.path.expanduser("/home/ubuntu/LYT/lerobot_aloha/collect_data/aloha.yaml")

    for setup_step in (self.create_ui, self.setup_connections, self.load_default_config):
        setup_step()
def create_ui(self):
    """Create the tab widget with its three pages and the control button row."""
    self.tabs = QTabWidget()
    self.main_layout.addWidget(self.tabs)

    # (attribute holding the page, tab title, method filling the page)
    pages = (
        ("general_tab", "General Settings", self.create_general_tab),
        ("camera_tab", "Camera Settings", self.create_camera_tab),
        ("arm_tab", "Arm Settings", self.create_arm_tab),
    )
    for attribute, title, fill_page in pages:
        page = QWidget()
        setattr(self, attribute, page)
        self.tabs.addTab(page, title)
        fill_page()

    # Control buttons, laid out left to right in this order.
    self.control_group = QGroupBox("Control")
    button_row = QHBoxLayout()
    buttons = (
        ("load_config_button", "Load Config"),
        ("save_config_button", "Save Config"),
        ("start_button", "Start Recording"),
        ("stop_button", "Stop Recording"),
        ("exit_button", "Exit"),
    )
    for attribute, caption in buttons:
        button = QPushButton(caption)
        setattr(self, attribute, button)
        button_row.addWidget(button)
    self.control_group.setLayout(button_row)
    self.main_layout.addWidget(self.control_group)
def create_general_tab(self):
    """Fill the "General Settings" page.

    The page holds four group boxes: config file path, dataset directory,
    task settings (name plus four spin boxes) and two option check boxes.
    """
    page_layout = QVBoxLayout(self.general_tab)

    def add_group(title, box_layout, widgets):
        """Put ``widgets`` into a titled group box appended to the page."""
        group = QGroupBox(title)
        for widget in widgets:
            box_layout.addWidget(widget)
        group.setLayout(box_layout)
        page_layout.addWidget(group)

    def make_spin(lowest, highest):
        """Create a spin box limited to ``lowest``..``highest``."""
        spin = QSpinBox()
        spin.setRange(lowest, highest)
        return spin

    # Config File Path
    self.config_path_edit = QLineEdit(self.config_path)
    self.browse_config_button = QPushButton("Browse...")
    add_group("Configuration File", QHBoxLayout(), [
        QLabel("Config File:"),
        self.config_path_edit,
        self.browse_config_button,
    ])

    # Dataset Directory
    self.dataset_dir_edit = QLineEdit()
    self.browse_dir_button = QPushButton("Browse...")
    add_group("Dataset Directory", QHBoxLayout(), [
        QLabel("Dataset Directory:"),
        self.dataset_dir_edit,
        self.browse_dir_button,
    ])

    # Task Settings
    self.task_name_edit = QLineEdit()
    self.episode_idx_spin = make_spin(0, 9999)
    self.max_timesteps_spin = make_spin(1, 10000)
    self.num_episodes_spin = make_spin(1, 1000)
    self.frame_rate_spin = make_spin(1, 60)
    add_group("Task Settings", QVBoxLayout(), [
        QLabel("Task Name:"), self.task_name_edit,
        QLabel("Episode Index:"), self.episode_idx_spin,
        QLabel("Max Timesteps:"), self.max_timesteps_spin,
        QLabel("Number of Episodes:"), self.num_episodes_spin,
        QLabel("Frame Rate:"), self.frame_rate_spin,
    ])

    # Options
    self.use_robot_base_check = QCheckBox("Use Robot Base")
    self.use_depth_image_check = QCheckBox("Use Depth Image")
    add_group("Options", QVBoxLayout(), [
        self.use_robot_base_check,
        self.use_depth_image_check,
    ])

    page_layout.addStretch()
def create_camera_tab(self):
    """Fill the "Camera Settings" page with colour and depth topic editors."""
    page_layout = QVBoxLayout(self.camera_tab)

    # (group title, ((attribute for the line edit, label caption), ...))
    sections = (
        ("Color Image Topics", (
            ("img_front_topic_edit", "Front Camera Topic:"),
            ("img_left_topic_edit", "Left Camera Topic:"),
            ("img_right_topic_edit", "Right Camera Topic:"),
        )),
        ("Depth Image Topics", (
            ("img_front_depth_topic_edit", "Front Depth Topic:"),
            ("img_left_depth_topic_edit", "Left Depth Topic:"),
            ("img_right_depth_topic_edit", "Right Depth Topic:"),
        )),
    )
    for title, fields in sections:
        group = QGroupBox(title)
        column = QVBoxLayout()
        editors = []
        for attribute, _caption in fields:
            editor = QLineEdit()
            setattr(self, attribute, editor)
            editors.append(editor)
        # Each editor sits directly below its caption.
        for (_attribute, caption), editor in zip(fields, editors):
            column.addWidget(QLabel(caption))
            column.addWidget(editor)
        group.setLayout(column)
        page_layout.addWidget(group)

    page_layout.addStretch()
def create_arm_tab(self):
    """Fill the "Arm Settings" page with master, puppet and base topic editors."""
    page_layout = QVBoxLayout(self.arm_tab)

    # (group title, ((attribute for the line edit, label caption), ...))
    sections = (
        ("Master Arm Topics", (
            ("master_arm_left_topic_edit", "Master Left Arm Topic:"),
            ("master_arm_right_topic_edit", "Master Right Arm Topic:"),
        )),
        ("Puppet Arm Topics", (
            ("puppet_arm_left_topic_edit", "Puppet Left Arm Topic:"),
            ("puppet_arm_right_topic_edit", "Puppet Right Arm Topic:"),
        )),
        ("Robot Base Topic", (
            ("robot_base_topic_edit", "Robot Base Topic:"),
        )),
    )
    for title, fields in sections:
        group = QGroupBox(title)
        column = QVBoxLayout()
        editors = []
        for attribute, _caption in fields:
            editor = QLineEdit()
            setattr(self, attribute, editor)
            editors.append(editor)
        # Each editor sits directly below its caption.
        for (_attribute, caption), editor in zip(fields, editors):
            column.addWidget(QLabel(caption))
            column.addWidget(editor)
        group.setLayout(column)
        page_layout.addWidget(group)

    page_layout.addStretch()
def setup_connections(self):
    """Connect every button's ``clicked`` signal to its handler."""
    wiring = (
        (self.load_config_button, self.load_config),
        (self.save_config_button, self.save_config),
        (self.browse_config_button, self.browse_config_file),
        (self.browse_dir_button, self.browse_dataset_dir),
        (self.start_button, self.start_recording),
        (self.stop_button, self.stop_recording),
        (self.exit_button, self.close),
    )
    for button, handler in wiring:
        button.clicked.connect(handler)
def load_default_config(self):
    """Populate the UI with the built-in default settings."""
    defaults = {}
    # General settings
    defaults['dataset_dir'] = '/home/ubuntu/LYT/lerobot_aloha/datasets/3camera'
    defaults['task_name'] = 'aloha_mobile_dummy'
    defaults['episode_idx'] = 0
    defaults['max_timesteps'] = 500
    defaults['camera_names'] = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
    defaults['num_episodes'] = 50
    # Colour and depth image topics
    for side, camera in (('front', 'f'), ('left', 'l'), ('right', 'r')):
        defaults[f'img_{side}_topic'] = f'/camera_{camera}/color/image_raw'
    for side, camera in (('front', 'f'), ('left', 'l'), ('right', 'r')):
        defaults[f'img_{side}_depth_topic'] = f'/camera_{camera}/depth/image_raw'
    # Arm and base topics
    defaults['master_arm_left_topic'] = '/master/joint_left'
    defaults['master_arm_right_topic'] = '/master/joint_right'
    defaults['puppet_arm_left_topic'] = '/puppet/joint_left'
    defaults['puppet_arm_right_topic'] = '/puppet/joint_right'
    defaults['robot_base_topic'] = '/odom'
    # Options
    defaults['use_robot_base'] = False
    defaults['use_depth_image'] = False
    defaults['frame_rate'] = 30

    # Update UI with default values
    self.update_ui_from_config(defaults)
def update_ui_from_config(self, config):
    """Copy the values of a config dictionary into the UI widgets.

    Missing keys fall back to an empty string for text fields, ``False`` for
    check boxes and 0 / 500 / 1 / 30 for episode index, max timesteps,
    number of episodes and frame rate.
    """
    # (widget setter, config key, fallback used when the key is absent)
    bindings = (
        (self.dataset_dir_edit.setText, 'dataset_dir', ''),
        (self.task_name_edit.setText, 'task_name', ''),
        (self.episode_idx_spin.setValue, 'episode_idx', 0),
        (self.max_timesteps_spin.setValue, 'max_timesteps', 500),
        (self.num_episodes_spin.setValue, 'num_episodes', 1),
        (self.frame_rate_spin.setValue, 'frame_rate', 30),
        (self.img_front_topic_edit.setText, 'img_front_topic', ''),
        (self.img_left_topic_edit.setText, 'img_left_topic', ''),
        (self.img_right_topic_edit.setText, 'img_right_topic', ''),
        (self.img_front_depth_topic_edit.setText, 'img_front_depth_topic', ''),
        (self.img_left_depth_topic_edit.setText, 'img_left_depth_topic', ''),
        (self.img_right_depth_topic_edit.setText, 'img_right_depth_topic', ''),
        (self.master_arm_left_topic_edit.setText, 'master_arm_left_topic', ''),
        (self.master_arm_right_topic_edit.setText, 'master_arm_right_topic', ''),
        (self.puppet_arm_left_topic_edit.setText, 'puppet_arm_left_topic', ''),
        (self.puppet_arm_right_topic_edit.setText, 'puppet_arm_right_topic', ''),
        (self.robot_base_topic_edit.setText, 'robot_base_topic', ''),
        (self.use_robot_base_check.setChecked, 'use_robot_base', False),
        (self.use_depth_image_check.setChecked, 'use_depth_image', False),
    )
    for apply_value, key, fallback in bindings:
        apply_value(config.get(key, fallback))
def get_config_from_ui(self):
    """Collect the current widget values into a config dictionary.

    ``camera_names`` is not editable in the UI and is always the fixed list
    of three camera names.
    """
    # (config key, zero-argument callable producing the value)
    readers = (
        ('dataset_dir', self.dataset_dir_edit.text),
        ('task_name', self.task_name_edit.text),
        ('episode_idx', self.episode_idx_spin.value),
        ('max_timesteps', self.max_timesteps_spin.value),
        ('camera_names', lambda: ['cam_high', 'cam_left_wrist', 'cam_right_wrist']),
        ('num_episodes', self.num_episodes_spin.value),
        ('img_front_topic', self.img_front_topic_edit.text),
        ('img_left_topic', self.img_left_topic_edit.text),
        ('img_right_topic', self.img_right_topic_edit.text),
        ('img_front_depth_topic', self.img_front_depth_topic_edit.text),
        ('img_left_depth_topic', self.img_left_depth_topic_edit.text),
        ('img_right_depth_topic', self.img_right_depth_topic_edit.text),
        ('master_arm_left_topic', self.master_arm_left_topic_edit.text),
        ('master_arm_right_topic', self.master_arm_right_topic_edit.text),
        ('puppet_arm_left_topic', self.puppet_arm_left_topic_edit.text),
        ('puppet_arm_right_topic', self.puppet_arm_right_topic_edit.text),
        ('robot_base_topic', self.robot_base_topic_edit.text),
        ('use_robot_base', self.use_robot_base_check.isChecked),
        ('use_depth_image', self.use_depth_image_check.isChecked),
        ('frame_rate', self.frame_rate_spin.value),
    )
    return {key: read() for key, read in readers}
def browse_config_file(self):
    """Let the user pick a YAML config file and load it into the UI."""
    selection = QFileDialog.getOpenFileName(
        self, "Select Config File", "", "YAML Files (*.yaml *.yml)"
    )
    chosen_path, _selected_filter = selection
    # An empty path means that nothing was selected.
    if not chosen_path:
        return
    self.config_path_edit.setText(chosen_path)
    self.load_config()
def browse_dataset_dir(self):
    """Let the user pick the dataset directory and show it in the UI."""
    chosen_dir = QFileDialog.getExistingDirectory(
        self, "Select Dataset Directory"
    )
    # An empty result means that nothing was selected.
    if not chosen_dir:
        return
    self.dataset_dir_edit.setText(chosen_dir)
def load_config(self):
    """Load the YAML file named in the config path field into the UI.

    A warning dialog is shown when the file does not exist.  Any failure
    while reading, parsing or applying the file is reported in an error
    dialog; this includes a file that is empty or whose top level is not a
    mapping.
    """
    config_path = self.config_path_edit.text()
    if not os.path.exists(config_path):
        QMessageBox.warning(self, "Warning", f"Config file not found: {config_path}")
        return

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        # safe_load gives None for an empty file and may give a list or a
        # scalar; report that clearly instead of failing later on
        # ``config.get``.
        if not isinstance(config, dict):
            raise ValueError("config file must contain a mapping of settings")
        self.update_ui_from_config(config)
        self.statusBar().showMessage(f"Config loaded from {config_path}", 3000)
    except Exception as e:
        QMessageBox.critical(self, "Error", f"Failed to load config: {str(e)}")
def save_config(self):
    """Write the current UI values as YAML to the path in the config field.

    An empty path only produces a warning dialog; any failure while
    collecting or writing the values is shown in an error dialog.
    """
    target_path = self.config_path_edit.text()
    if target_path:
        try:
            # Read the UI before opening the file so that a failure here
            # does not leave a truncated file behind.
            settings = self.get_config_from_ui()
            with open(target_path, 'w') as handle:
                yaml.dump(settings, handle, default_flow_style=False)
            self.statusBar().showMessage(f"Config saved to {target_path}", 3000)
        except Exception as error:
            QMessageBox.critical(self, "Error", f"Failed to save config: {str(error)}")
    else:
        QMessageBox.warning(self, "Warning", "Please specify a config file path")
def start_recording(self):
    """Validate the UI values, write them to a temporary config and record.

    The dataset directory and the task name must not be empty.  The config
    is written to ``/tmp/aloha_temp_config.yaml`` and handed to ``main``,
    which is called directly from this handler and only returns when the
    recording session has ended.
    NOTE(review): that call presumably keeps the window unresponsive for
    the whole session - confirm.
    """
    try:
        # Save current config to a temporary file
        temp_config_path = "/tmp/aloha_temp_config.yaml"
        config = self.get_config_from_ui()

        # Validate inputs
        if not config['dataset_dir']:
            QMessageBox.warning(self, "Warning", "Dataset directory cannot be empty!")
            return
        if not config['task_name']:
            QMessageBox.warning(self, "Warning", "Task name cannot be empty!")
            return

        with open(temp_config_path, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)

        self.statusBar().showMessage("Recording started...")

        # Start recording with the temporary config file
        exit_code = main(temp_config_path)

        # main() may finish without an explicit status (None); the
        # command-line entry point of collect_data treats that as 0, so it
        # must not be reported as an error here either.
        if exit_code is None or exit_code == 0:
            self.statusBar().showMessage("Recording completed successfully!", 5000)
        else:
            self.statusBar().showMessage("Recording completed with errors", 5000)
    except Exception as e:
        QMessageBox.critical(self, "Error", f"An error occurred: {str(e)}")
        self.statusBar().showMessage("Recording failed", 5000)
def stop_recording(self):
    """Placeholder handler: only reports that a stop was requested.

    No recording is actually interrupted; a status message and an
    information dialog are shown instead.
    """
    status_bar = self.statusBar()
    status_bar.showMessage("Recording stopped", 5000)
    notice = "Stop recording requested. This would stop the recording in a real implementation."
    QMessageBox.information(self, "Info", notice)
if __name__ == "__main__":
    # Create the Qt application, show the main window and hand control to
    # the event loop; its return value becomes the process exit status.
    app = QApplication(sys.argv)
    window = AlohaDataCollectionGUI()
    window.show()
    exit_status = app.exec_()
    sys.exit(exit_status)

View File

@@ -1,487 +0,0 @@
import logging
import time
from dataclasses import asdict
from pprint import pformat
from pprint import pprint
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,
RecordControlConfig,
RemoteRobotConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.common.robot_devices.control_utils import (
# init_keyboard_listener,
record_episode,
stop_recording,
is_headless
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.common.utils.utils import get_safe_torch_device
from contextlib import nullcontext
from copy import copy
import torch
import rospy
import cv2
from lerobot.configs import parser
from agilex_robot import AgilexRobot
########################################################################################
# Control modes
########################################################################################
def predict_action(observation, policy, device, use_amp):
    """Run one policy inference step on a single observation.

    A shallow copy of ``observation`` is made so the caller's dict is not
    modified. Entries whose key contains ``"image"`` are converted to
    float32, divided by 255 and permuted with ``(2, 0, 1)``; every entry
    then gets a leading batch dimension and is moved to ``device``.
    Autocast is enabled only when ``device`` is CUDA and ``use_amp`` is set.

    Returns the action from ``policy.select_action`` with the batch
    dimension removed, moved to the CPU.
    """
    batch = copy(observation)
    amp_enabled = device.type == "cuda" and use_amp
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if amp_enabled else nullcontext(),
    ):
        for key in batch:
            tensor = batch[key]
            if "image" in key:
                # Pixel values: float32 scaled by 1/255, then axes (2, 0, 1).
                tensor = tensor.type(torch.float32) / 255
                tensor = tensor.permute(2, 0, 1).contiguous()
            batch[key] = tensor.unsqueeze(0).to(device)

        # Query the policy for the next action given the prepared batch.
        predicted = policy.select_action(batch)

        # Drop the batch dimension and bring the result back to the CPU.
        return predicted.squeeze(0).to("cpu")
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Run the observe/act loop until time runs out, ROS shuts down or an early exit is requested.

    Args:
        robot: Object providing ``teleop_step``, ``capture_observation`` and
            ``send_action``.
        control_time_s: Maximum loop duration in seconds; ``None`` means no limit.
        teleoperate: When true, observation and action come from
            ``robot.teleop_step()``; otherwise the observation is captured and,
            if ``policy`` is given, the predicted action is sent to the robot.
        display_cameras: Show non-depth image observations in OpenCV windows
            (ignored in headless environments).
        dataset: When given, every step is appended with ``dataset.add_frame``.
        events: Dict with an ``"exit_early"`` flag; the flag is cleared and the
            loop left when it is set.
        policy: Optional policy used when not teleoperating.
        fps: Target loop frequency; ``None`` disables pacing.
        single_task: Task description stored with every recorded frame.

    Raises:
        ValueError: If a dataset is given without ``single_task``, with a
            mismatching fps, or without any source of actions.
    """
    # TODO(rcadene): Add option to record logs
    if events is None:
        events = {"exit_early": False}
    if control_time_s is None:
        control_time_s = float("inf")
    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")
    if dataset is not None and fps is not None and dataset.fps != fps:
        # Use the attribute here: indexing the dataset with "fps" would itself fail.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
    if dataset is not None and not teleoperate and policy is None:
        # Without teleoperation or a policy there is no action to store.
        raise ValueError("Recording to a dataset requires `teleoperate=True` or a `policy`.")

    # Window tiling: fixed tile size on an assumed 1920-pixel-wide screen.
    screen_width = 1920
    window_width = 640
    window_height = 480
    max_columns = max(1, screen_width // window_width)

    timestamp = 0
    start_episode_t = time.perf_counter()
    # Only build a ROS rate when a frequency was requested.
    rate = rospy.Rate(fps) if fps is not None else None
    print_flag = True
    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()
        action = None
        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                # Print the notice once per streak of failed synchronisations.
                if print_flag:
                    print("sync data fail, retrying...\n")
                    print_flag = False
                if rate is not None:
                    rate.sleep()
                else:
                    time.sleep(0.01)
                continue
            print_flag = True
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            # Depth images are not displayed; filter them first so the
            # remaining windows tile without gaps.
            image_keys = [key for key in observation if "image" in key and "depth" not in key]
            for idx, key in enumerate(image_keys):
                # Convert from RGB to BGR for OpenCV display.
                image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
                cv2.imshow(key, image)
                # Place each window on a row-major grid.
                row, col = divmod(idx, max_columns)
                cv2.moveWindow(key, col * window_width, row * window_height)
            # Give OpenCV 1 ms to process GUI events.
            cv2.waitKey(1)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
def init_keyboard_listener():
    """Create the shared event flags and, when possible, a keyboard listener.

    Key bindings (handled in ``on_press``): right arrow ends the current
    loop, left arrow ends it and requests a re-record, escape stops the
    recording, up arrow sets the ``record_start`` flag. Monitoring keyboard
    events might require sudo permission for the terminal.

    Returns:
        ``(listener, events)``; ``listener`` is ``None`` in a headless
        environment.
    """
    events = {
        "exit_early": False,
        "record_start": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events.update(exit_early=True, record_start=False)
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                events.update(rerecord_episode=True, exit_early=True)
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events.update(stop_recording=True, exit_early=True)
            elif key == keyboard.Key.up:
                print("Up arrow pressed. Start data recording...")
                events.update(record_start=True)
        except Exception as e:
            print(f"Error handling key press: {e}")

    key_listener = keyboard.Listener(on_press=on_press)
    key_listener.start()
    return key_listener, events
def stop_recording(robot, listener, display_cameras):
    """Stop the keyboard listener and close camera windows.

    Nothing is done in a headless environment. ``robot`` is accepted but
    not used.
    """
    if is_headless():
        return
    if listener is not None:
        listener.stop()
    if display_cameras:
        cv2.destroyAllWindows()
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    """Record one episode by delegating to ``control_loop``.

    Teleoperation is enabled exactly when no policy is supplied.
    """
    use_teleop = policy is None
    control_loop(
        robot,
        control_time_s=episode_time_s,
        teleoperate=use_teleop,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        single_task=single_task,
    )
def record(
    robot,
    cfg
) -> LeRobotDataset:
    """Record ``cfg.num_episodes`` episodes into a LeRobot dataset.

    With ``cfg.resume`` an existing dataset is opened (and its image writer
    started when the robot has cameras); otherwise a new dataset is created
    from ``robot.features``. Episodes are recorded one after another; the
    keyboard flags allow ending an episode, re-recording it, or stopping
    early. The dataset is optionally pushed to the hub and then returned.
    """
    # TODO(rcadene): Add option to record logs
    if not cfg.resume:
        # Create an empty dataset
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
        )
    else:
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        if len(robot.cameras) > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
            )

    # Load pretrained policy
    if cfg.policy is None:
        policy = None
    else:
        policy = make_policy(cfg.policy, ds_meta=dataset.meta)

    listener, events = init_keyboard_listener()

    log_say("Warmup record", cfg.play_sounds)
    print()
    print(f"开始记录轨迹,共需要记录{cfg.num_episodes}\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")

    recorded_episodes = 0
    while recorded_episodes < cfg.num_episodes:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Announce the environment reset, except after the last episode
        # (unless that episode is about to be re-recorded) or when stopping.
        # TODO(rcadene): add an option to enable teleoperation during reset
        episodes_remaining = recorded_episodes < cfg.num_episodes - 1
        if not events["stop_recording"] and (episodes_remaining or events["rerecord_episode"]):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")

        if events["rerecord_episode"]:
            # Discard the buffered frames and record the same episode again.
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset
def replay(
    robot: AgilexRobot,
    cfg,
):
    """Send the recorded actions of episode ``cfg.episode`` to the robot.

    Actions are read from the dataset's ``"action"`` column and sent one per
    frame, waiting after each so a step takes about ``1 / cfg.fps`` seconds.
    """
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs
    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    recorded_actions = dataset.hf_dataset.select_columns("action")

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    for frame_idx in range(dataset.num_frames):
        step_start_t = time.perf_counter()
        robot.send_action(recorded_actions[frame_idx]["action"])
        elapsed_s = time.perf_counter() - step_start_t
        busy_wait(1 / cfg.fps - elapsed_s)
import argparse
def get_arguments(argv=None):
    """Parse the recording/replay options from the command line.

    Every option defaults to the value that used to be hard-coded, so
    calling the script without arguments yields the same namespace as
    before.

    Args:
        argv: Optional list of argument strings; ``None`` means
            ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the options below plus ``policy``,
        which is always ``None`` (no policy can be selected from the CLI).
    """
    parser = argparse.ArgumentParser(description="Record or replay robot episodes.")
    flag = argparse.BooleanOptionalAction  # provides --name / --no-name pairs

    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--resume", action=flag, default=False)
    parser.add_argument("--repo_id", default="move_the_bottle_from_the_right_to_the_scale_right")
    parser.add_argument("--root", default="/home/ubuntu/LYT/aloha_lerobot/data4")
    parser.add_argument("--episode", type=int, default=0, help="episode index used by replay")
    parser.add_argument("--num_image_writer_processes", type=int, default=0)
    parser.add_argument("--num_image_writer_threads_per_camera", type=int, default=4)
    parser.add_argument("--video", action=flag, default=True)
    parser.add_argument("--num_episodes", type=int, default=100)
    parser.add_argument("--episode_time_s", type=float, default=30000)
    parser.add_argument("--play_sounds", action=flag, default=False)
    parser.add_argument("--display_cameras", action=flag, default=True)
    parser.add_argument("--single_task", default="move the bottle from the right to the scale right")
    parser.add_argument("--use_depth_image", action=flag, default=False)
    parser.add_argument("--use_base", action=flag, default=False)
    parser.add_argument("--push_to_hub", action=flag, default=False)
    # Read by record() only when --push_to_hub is enabled.
    parser.add_argument("--tags", nargs="*", default=None)
    parser.add_argument("--private", action=flag, default=False)
    parser.add_argument("--control_type", choices=("record", "replay"), default="record")

    args = parser.parse_args(argv)
    args.policy = None
    return args
# @parser.wrap()
# def control_robot(cfg: ControlPipelineConfig):
# init_logging()
# logging.info(pformat(asdict(cfg)))
# # robot = make_robot_from_config(cfg.robot)
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# if isinstance(cfg.control, RecordControlConfig):
# print(cfg.control)
# record(robot, cfg.control)
# elif isinstance(cfg.control, ReplayControlConfig):
# replay(robot, cfg.control)
# # if robot.is_connected:
# # # Disconnect manually to avoid a "Core dump" during process
# # # termination due to camera threads not properly exiting.
# # robot.disconnect()
# @parser.wrap()
def control_robot(cfg):
    """Build the Agilex robot and run the control mode selected by ``cfg.control_type``.

    Args:
        cfg: Namespace-like object; ``control_type`` must be ``"record"`` or
            ``"replay"``. It is also passed to the robot and to the selected
            control function.

    Raises:
        ValueError: If ``cfg.control_type`` is not a supported mode. The
            check happens before the robot is constructed.
    """
    if cfg.control_type not in ("record", "replay"):
        raise ValueError(
            f"Unsupported control_type {cfg.control_type!r}; expected 'record' or 'replay'."
        )

    from agilex_robot import AgilexRobot

    robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
    if cfg.control_type == "record":
        record(robot, cfg)
    else:
        replay(robot, cfg)
if __name__ == "__main__":
    # Script entry point: build the argument namespace and run the selected
    # control mode.
    cfg = get_arguments()
    control_robot(cfg)
# control_robot()
# cfg = get_arguments()
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# print(robot.features.items())
# print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
# record(robot, cfg)
# capture = robot.capture_observation()
# import torch
# torch.save(capture, "test.pt")
# action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094,
# 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]],
# device='cpu')
# robot.send_action(action.squeeze(0))
# print()

View File

@@ -0,0 +1,292 @@
"""
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
"""
import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import os
import h5py
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
import numpy as np
import torch
import tqdm
import tyro
# Override the imported LEROBOT_HOME with a custom local path
LEROBOT_HOME = Path("/home/ubuntu/hdd0/lerobot_datasets/3camera")
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    """Immutable options that ``create_empty_dataset`` forwards to ``LeRobotDataset.create``."""

    # Forwarded as ``use_videos``.
    use_videos: bool = True
    # Forwarded as ``tolerance_s``.
    tolerance_s: float = 0.0001
    # Forwarded as ``image_writer_processes``.
    image_writer_processes: int = 10
    # Forwarded as ``image_writer_threads``.
    image_writer_threads: int = 5
    # Forwarded as ``video_backend``; ``None`` is passed through unchanged.
    video_backend: str | None = None


# Shared default instance used as the default argument of the functions below.
DEFAULT_DATASET_CONFIG = DatasetConfig()
def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    cameras: list[str] | None = None,
    fps: int = 30,
) -> LeRobotDataset:
    """Create an empty LeRobot dataset with the Aloha feature layout.

    Any existing directory ``LEROBOT_HOME / repo_id`` is deleted first.

    Args:
        repo_id: Dataset identifier; also the sub-directory under ``LEROBOT_HOME``.
        robot_type: Stored as the dataset's robot type.
        mode: ``dtype`` used for the camera features.
        has_velocity: Add an ``observation.velocity`` feature.
        has_effort: Add an ``observation.effort`` feature.
        dataset_config: Writer/backend options forwarded to ``LeRobotDataset.create``.
        cameras: Camera names to declare; defaults to ``cam_high``,
            ``cam_left_wrist`` and ``cam_right_wrist``.
        fps: Frame rate stored in the dataset.

    Returns:
        The newly created dataset.
    """
    motors = [
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
    ]
    if cameras is None:
        # Default three-camera setup.
        cameras = [
            "cam_high",
            "cam_left_wrist",
            "cam_right_wrist",
        ]

    def motor_feature() -> dict:
        # One float32 value per motor; every motor-level feature shares this layout.
        return {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    features = {
        "observation.state": motor_feature(),
        "action": motor_feature(),
    }
    if has_velocity:
        features["observation.velocity"] = motor_feature()
    if has_effort:
        features["observation.effort"] = motor_feature()

    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    dataset_root = LEROBOT_HOME / repo_id
    if dataset_root.exists():
        shutil.rmtree(dataset_root)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        root=dataset_root,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )
def get_cameras(hdf5_files: list[Path]) -> list[str]:
    """Return the non-depth camera names stored in the first episode file."""
    camera_names = []
    with h5py.File(hdf5_files[0], "r") as ep:
        for key in ep["/observations/images"].keys():  # noqa: SIM118
            # ignore depth channel, not currently handled
            if "depth" in key:
                continue
            camera_names.append(key)
    return camera_names
def has_velocity(hdf5_files: list[Path]) -> bool:
    """Tell whether the first episode file contains ``/observations/qvel``."""
    first_episode = hdf5_files[0]
    with h5py.File(first_episode, "r") as ep:
        found = "/observations/qvel" in ep
    return found
def has_effort(hdf5_files: list[Path]) -> bool:
    """Tell whether the first episode file contains ``/observations/effort``."""
    first_episode = hdf5_files[0]
    with h5py.File(first_episode, "r") as ep:
        found = "/observations/effort" in ep
    return found
def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    """Load every frame of each requested camera from an open episode file.

    A 4-dimensional image dataset is read as-is; otherwise each entry is
    decoded with OpenCV and converted from BGR to RGB.

    Returns:
        Mapping from camera name to the array of its frames.
    """

    def read_camera(camera: str) -> np.ndarray:
        stored = ep[f"/observations/images/{camera}"]
        if stored.ndim == 4:
            # Uncompressed: load all images in RAM at once.
            return stored[:]

        import cv2

        # Compressed: decode one image after the other.
        decoded = [cv2.cvtColor(cv2.imdecode(data, 1), cv2.COLOR_BGR2RGB) for data in stored]
        return np.array(decoded)

    return {camera: read_camera(camera) for camera in cameras}
def load_raw_episode_data(
    ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Read one raw HDF5 episode.

    Returns:
        ``(imgs_per_cam, state, action, velocity, effort)`` where
        ``velocity`` and ``effort`` are ``None`` when the file has no
        ``/observations/qvel`` or ``/observations/effort`` entry.
    """
    # Fixed camera set for this conversion.
    camera_names = [
        "cam_high",
        "cam_left_wrist",
        "cam_right_wrist",
    ]
    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        has_qvel = "/observations/qvel" in ep
        velocity = torch.from_numpy(ep["/observations/qvel"][:]) if has_qvel else None

        has_effort_data = "/observations/effort" in ep
        effort = torch.from_numpy(ep["/observations/effort"][:]) if has_effort_data else None

        imgs_per_cam = load_raw_images_per_camera(ep, camera_names)

    return imgs_per_cam, state, action, velocity, effort
def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    """Append the selected raw episodes to ``dataset`` frame by frame.

    Args:
        dataset: Target dataset.
        hdf5_files: Raw episode files.
        task: Task label saved with every episode.
        episodes: Indices into ``hdf5_files``; ``None`` selects all files.

    Returns:
        The same ``dataset`` object.
    """
    selected = range(len(hdf5_files)) if episodes is None else episodes

    for ep_idx in tqdm.tqdm(selected):
        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(hdf5_files[ep_idx])

        for i in range(state.shape[0]):
            frame = {
                "observation.state": state[i],
                "action": action[i],
            }
            frame.update(
                {f"observation.images.{camera}": img_array[i] for camera, img_array in imgs_per_cam.items()}
            )
            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]
            dataset.add_frame(frame)

        dataset.save_episode(task=task)

    return dataset
def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = False,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    """Convert raw Aloha HDF5 episodes into a LeRobot dataset.

    The raw input is located (and downloaded when missing) and checked
    before any existing output dataset is deleted, so a wrong ``raw_dir``
    no longer destroys a previously converted dataset.

    Args:
        raw_dir: Directory holding ``episode_*.hdf5`` files.
        repo_id: Output dataset identifier under ``LEROBOT_HOME``.
        raw_repo_id: Source used by ``download_raw`` when ``raw_dir`` does not exist.
        task: Task label saved with every episode.
        episodes: Optional subset of episode indices to convert.
        push_to_hub: Push the finished dataset to the hub.
        is_mobile: Select the ``mobile_aloha`` robot type instead of ``aloha``.
        mode: ``dtype`` used for the camera features.
        dataset_config: Options forwarded to ``create_empty_dataset``.

    Raises:
        ValueError: If ``raw_dir`` does not exist and ``raw_repo_id`` is not given.
        FileNotFoundError: If ``raw_dir`` contains no ``episode_*.hdf5`` file.
    """
    print(LEROBOT_HOME)

    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        download_raw(raw_dir, repo_id=raw_repo_id)

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
    if not hdf5_files:
        # Fail with a clear message instead of an IndexError further down.
        raise FileNotFoundError(f"No episode_*.hdf5 files found in {raw_dir}")

    # Only now that the input is known to be usable, drop any previous output.
    if (LEROBOT_HOME / repo_id).exists():
        shutil.rmtree(LEROBOT_HOME / repo_id)

    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()
if __name__ == "__main__":
    # Expose port_aloha's parameters as the command-line interface via tyro.
    tyro.cli(port_aloha)

View File

@@ -1,3 +0,0 @@
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5
# export LD_LIBRARY_PATH=/home/ubuntu/miniconda3/envs/lerobot/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH

View File

@@ -1,769 +0,0 @@
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -- coding: UTF-8
"""
#!/usr/bin/python3
"""
import torch
import numpy as np
import os
import pickle
import argparse
from einops import rearrange
import collections
from collections import deque
import rospy
from std_msgs.msg import Header
from geometry_msgs.msg import Twist
from sensor_msgs.msg import JointState, Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
import time
import threading
import math
import threading
import sys
sys.path.append("./")
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']}
inference_thread = None
inference_lock = threading.Lock()
inference_actions = None
inference_timestep = None
def actions_interpolation(args, pre_action, actions, stats):
    """Build a smoothed, normalised action chunk starting from ``pre_action``.

    The predicted chunk ``actions[0]`` is de-normalised with the action
    statistics. Starting at the predicted step that differs most from
    ``pre_action`` (joints 6 and 13 are ignored for that comparison),
    intermediate points are inserted so that no joint moves more than its
    entry in ``args.arm_steps_length`` (used for both arms) per step. The
    sequence is padded with its last element or truncated to
    ``args.chunk_size`` steps and normalised again.

    The result is normalised with the *action* statistics, i.e. the exact
    inverse of the de-normalisation applied here, so that de-normalising
    the returned chunk yields the interpolated actions.

    Args:
        args: Needs ``arm_steps_length`` and ``chunk_size``.
        pre_action: Last executed (de-normalised) action vector.
        actions: Normalised predicted actions; only ``actions[0]`` is used.
        stats: Mapping with ``action_mean`` and ``action_std``.

    Returns:
        Array with a leading batch axis of size 1 followed by
        ``args.chunk_size`` action rows.
    """
    steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)

    def denormalize(a):
        return a * stats['action_std'] + stats['action_mean']

    def normalize(a):
        return (a - stats['action_mean']) / stats['action_std']

    result = [pre_action]
    post_action = denormalize(actions[0])
    num_joints = pre_action.shape[0]
    ignored_joints = (6, 13)

    # First predicted step with the largest summed joint distance to pre_action.
    max_diff_index = 0
    max_diff = -1
    for i in range(post_action.shape[0]):
        diff = sum(
            math.fabs(pre_action[j] - post_action[i][j])
            for j in range(num_joints)
            if j not in ignored_joints
        )
        if diff > max_diff:
            max_diff = diff
            max_diff_index = i

    for i in range(max_diff_index, post_action.shape[0]):
        # Number of extra points needed so every joint respects its step limit.
        step = max(
            math.floor(math.fabs(result[-1][j] - post_action[i][j]) / steps[j])
            for j in range(num_joints)
        )
        inter = np.linspace(result[-1], post_action[i], step + 2)
        result.extend(inter[1:])

    # Pad by repeating the final action, then drop pre_action and truncate.
    while len(result) < args.chunk_size + 1:
        result.append(result[-1])
    result = np.array(result)[1:args.chunk_size + 1]

    result = normalize(result)
    return result[np.newaxis, :]
def get_model_config(args):
    """Assemble the inference configuration dict from the parsed arguments.

    The random seed is fixed first. The policy-specific settings depend on
    ``args.policy_class`` (``'ACT'``, ``'CNNMLP'`` or ``'Diffusion'``).

    Raises:
        NotImplementedError: For any other policy class.
    """
    # Fix the random seed so runs are reproducible.
    set_seed(1)

    policy_class = args.policy_class
    if policy_class not in ('ACT', 'CNNMLP', 'Diffusion'):
        raise NotImplementedError

    # Settings shared by every policy class; CNNMLP always uses a chunk size of 1.
    policy_config = {
        'lr': args.lr,
        'lr_backbone': args.lr_backbone,
        'backbone': args.backbone,
        'masks': args.masks,
        'weight_decay': args.weight_decay,
        'dilation': args.dilation,
        'position_embedding': args.position_embedding,
        'loss_function': args.loss_function,
        'chunk_size': 1 if policy_class == 'CNNMLP' else args.chunk_size,
        'camera_names': task_config['camera_names'],
        'use_depth_image': args.use_depth_image,
        'use_robot_base': args.use_robot_base,
    }

    if policy_class == 'ACT':
        policy_config.update({
            'kl_weight': args.kl_weight,  # KL-divergence weight
            'hidden_dim': args.hidden_dim,  # hidden layer dimension
            'dim_feedforward': args.dim_feedforward,
            'enc_layers': args.enc_layers,
            'dec_layers': args.dec_layers,
            'nheads': args.nheads,
            'dropout': args.dropout,
            'pre_norm': args.pre_norm,
        })
    elif policy_class == 'Diffusion':
        policy_config.update({
            'observation_horizon': args.observation_horizon,
            'action_horizon': args.action_horizon,
            'num_inference_timesteps': args.num_inference_timesteps,
            'ema_power': args.ema_power,
        })

    return {
        'ckpt_dir': args.ckpt_dir,
        'ckpt_name': args.ckpt_name,
        'ckpt_stats_name': args.ckpt_stats_name,
        'episode_len': args.max_publish_step,
        'state_dim': args.state_dim,
        'policy_class': args.policy_class,
        'policy_config': policy_config,
        'temporal_agg': args.temporal_agg,
        'camera_names': task_config['camera_names'],
    }
def make_policy(policy_class, policy_config):
    """Instantiate the policy selected by ``policy_class``.

    Raises:
        NotImplementedError: If ``policy_class`` is not ``'ACT'``,
            ``'CNNMLP'`` or ``'Diffusion'``.
    """
    if policy_class == 'ACT':
        return ACTPolicy(policy_config)
    if policy_class == 'CNNMLP':
        return CNNMLPPolicy(policy_config)
    if policy_class == 'Diffusion':
        return DiffusionPolicy(policy_config)
    raise NotImplementedError
def get_image(observation, camera_names):
    """Stack the colour images of ``camera_names`` into one CUDA float tensor.

    Each image is rearranged from ``h w c`` to ``c h w``; the stack is
    divided by 255.0 and gets a leading batch dimension.
    """
    per_camera = [
        rearrange(observation['images'][cam_name], 'h w c -> c h w')
        for cam_name in camera_names
    ]
    stacked = np.stack(per_camera, axis=0)
    return torch.from_numpy(stacked / 255.0).float().cuda().unsqueeze(0)
def get_depth_image(observation, camera_names):
    """Stack the depth images of ``camera_names`` into one CUDA float tensor.

    The stack is divided by 255.0 and gets a leading batch dimension.
    """
    per_camera = [observation['images_depth'][cam_name] for cam_name in camera_names]
    stacked = np.stack(per_camera, axis=0)
    return torch.from_numpy(stacked / 255.0).float().cuda().unsqueeze(0)
def inference_process(args, config, ros_operator, policy, stats, t, pre_action):
    """Wait for one synchronised frame, run the policy and publish the result.

    The function polls ``ros_operator.get_frame()`` at ``args.publish_rate``
    until a frame is available (or ROS shuts down), builds the observation,
    queries ``policy`` and stores the predicted actions in the module-level
    ``inference_actions`` together with ``inference_timestep = t``.

    The shared globals are written inside ``with inference_lock`` so the
    lock is always released, even if publishing raises.

    Args:
        args: Parsed options (``publish_rate``, ``use_depth_image``,
            ``use_robot_base``, ``use_actions_interpolation`` ...).
        config: Inference configuration; ``camera_names`` gives the keys for
            the front, left and right images, in that order.
        ros_operator: Source of synchronised frames.
        policy: Callable ``policy(image, depth_image, qpos)``.
        stats: Normalisation statistics (``qpos_mean``, ``qpos_std`` ...).
        t: Time step stored alongside the actions.
        pre_action: Previous action used as the interpolation start; the
            current joint positions are used when it is ``None``.
    """
    global inference_lock
    global inference_actions
    global inference_timestep

    print_flag = True
    pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    rate = rospy.Rate(args.publish_rate)
    camera_names = config['camera_names']

    while not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if not result:
            # Report a failed synchronisation once per streak, then retry.
            if print_flag:
                print("syn fail")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
         puppet_arm_left, puppet_arm_right, robot_base) = result

        obs = collections.OrderedDict()
        obs['images'] = {
            camera_names[0]: img_front,
            camera_names[1]: img_left,
            camera_names[2]: img_right,
        }
        if args.use_depth_image:
            obs['images_depth'] = {
                camera_names[0]: img_front_depth,
                camera_names[1]: img_left_depth,
                camera_names[2]: img_right_depth,
            }

        obs['qpos'] = np.concatenate(
            (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
        obs['qvel'] = np.concatenate(
            (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
        obs['effort'] = np.concatenate(
            (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
        if args.use_robot_base:
            obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
            obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0)
        else:
            obs['base_vel'] = [0.0, 0.0]

        # Normalise qpos and move it to CUDA.
        qpos = pre_pos_process(obs['qpos'])
        qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
        # Current colour (and optionally depth) images.
        curr_image = get_image(obs, camera_names)
        curr_depth_image = None
        if args.use_depth_image:
            curr_depth_image = get_depth_image(obs, camera_names)

        start_time = time.time()
        all_actions = policy(curr_image, curr_depth_image, qpos)
        end_time = time.time()
        print("model cost time: ", end_time -start_time)

        # Prepare the result before taking the lock so it is held only briefly.
        new_actions = all_actions.cpu().detach().numpy()
        if pre_action is None:
            pre_action = obs['qpos']
        if args.use_actions_interpolation:
            new_actions = actions_interpolation(args, pre_action, new_actions, stats)

        with inference_lock:
            inference_actions = new_actions
            inference_timestep = t
        break
def model_inference(args, config, ros_operator, save_episode=True):
    """Load a trained policy and run the closed-loop control cycle.

    Steps: build the policy, load the checkpoint (a few keys are skipped),
    load the normalisation statistics, drive the arms to two fixed poses
    with a user confirmation in between, then repeatedly query the policy
    and publish the resulting joint (and optionally base) commands at
    ``args.publish_rate``.

    Returns ``False`` when the checkpoint could not be loaded; otherwise
    the function keeps running until ROS shuts down. ``save_episode`` is
    not used in this function.

    Raises:
        NotImplementedError: If ``config['policy_class']`` is not ``"ACT"``.
    """
    global inference_lock
    global inference_actions
    global inference_timestep
    global inference_thread
    set_seed(1000)
    # 1. Create the policy model.
    policy = make_policy(config['policy_class'], config['policy_config'])
    # 2. Load the model weights, skipping keys the policy does not expect.
    ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name'])
    state_dict = torch.load(ckpt_path)
    new_state_dict = {}
    for key, value in state_dict.items():
        if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]:
            continue
        if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]:
            continue
        new_state_dict[key] = value
    loading_status = policy.deserialize(new_state_dict)
    if not loading_status:
        print("ckpt path not exist")
        return False
    # 3. Move the model to CUDA and switch to evaluation mode.
    policy.cuda()
    policy.eval()
    # 4. Load the dataset statistics (action_mean, action_std, qpos_mean, qpos_std).
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load statistics files from a trusted source.
    stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name'])
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)
    # Pre- and post-processing (normalise qpos, de-normalise actions).
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']
    max_publish_step = config['episode_len']
    chunk_size = config['policy_config']['chunk_size']
    # Publish the initial poses: left0/right0 first, left1/right1 after confirmation.
    left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
    right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
    left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
    right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Enter any key to continue :")
    ros_operator.puppet_arm_publish_continuous(left1, right1)
    action = None
    # Inference
    with torch.inference_mode():
        while True and not rospy.is_shutdown():
            # Step counter of the current rollout.
            t = 0
            max_t = 0
            rate = rospy.Rate(args.publish_rate)
            if config['temporal_agg']:
                all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']])
            while t < max_publish_step and not rospy.is_shutdown():
                # Query the policy.
                if config['policy_class'] == "ACT":
                    if t >= max_t:
                        pre_action = action
                        # The worker thread is joined right away, so the
                        # inference runs synchronously with this loop.
                        inference_thread = threading.Thread(target=inference_process,
                                                            args=(args, config, ros_operator,
                                                                  policy, stats, t, pre_action))
                        inference_thread.start()
                        inference_thread.join()
                        inference_lock.acquire()
                        # NOTE(review): if no actions were produced, all_actions
                        # keeps its previous value (or is unbound on the first
                        # pass) -- confirm this cannot happen in practice.
                        if inference_actions is not None:
                            inference_thread = None
                            all_actions = inference_actions
                            inference_actions = None
                            max_t = t + args.pos_lookahead_step
                            if config['temporal_agg']:
                                all_time_actions[[t], t:t + chunk_size] = all_actions
                        inference_lock.release()
                    if config['temporal_agg']:
                        # Exponentially weighted average of the populated
                        # predictions stored for step t.
                        actions_for_curr_step = all_time_actions[:, t]
                        actions_populated = np.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = exp_weights[:, np.newaxis]
                        raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True)
                    else:
                        if args.pos_lookahead_step != 0:
                            raw_action = all_actions[:, t % args.pos_lookahead_step]
                        else:
                            raw_action = all_actions[:, t % chunk_size]
                else:
                    raise NotImplementedError
                action = post_process(raw_action[0])
                left_action = action[:7]  # first 7 values
                right_action = action[7:14]
                ros_operator.puppet_arm_publish(left_action, right_action)
                if args.use_robot_base:
                    vel_action = action[14:16]
                    ros_operator.robot_base_publish(vel_action)
                t += 1
                rate.sleep()
class RosOperator:
    """ROS interface used by the inference loop.

    Subscriber callbacks buffer camera images, puppet-arm joint states and
    base odometry in deques; ``get_frame`` pops one time-aligned set of
    messages; the ``*_publish*`` methods send joint and base commands.
    """

    # Joint names attached to every outgoing JointState command.
    JOINT_NAMES = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']
    # Maximum number of buffered messages per topic; the oldest is dropped first.
    MAX_QUEUE_LEN = 2000

    def __init__(self, args):
        """Store ``args``, create the buffers and set up the ROS node."""
        self.args = args
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.ctrl_state = False
        self.ctrl_state_lock = threading.Lock()
        self.init()
        self.init_ros()

    def init(self):
        """Create the image bridge, the message buffers and the publish lock."""
        self.bridge = CvBridge()
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()
        self.robot_base_deque = deque()
        # The lock starts out held; releasing it is the signal that asks a
        # running puppet_arm_publish_continuous thread to stop.
        self.puppet_arm_publish_lock = threading.Lock()
        self.puppet_arm_publish_lock.acquire()

    # ------------------------------------------------------------------
    # Publishing
    # ------------------------------------------------------------------
    def _publish_joint_positions(self, left, right):
        """Publish ``left`` / ``right`` as JointState positions on the arm topics."""
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # set the timestamp
        joint_state_msg.name = list(self.JOINT_NAMES)  # set the joint names
        joint_state_msg.position = left
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right
        self.puppet_arm_right_publisher.publish(joint_state_msg)

    def _wait_for_arm_positions(self, rate):
        """Wait until both arms have reported a joint state.

        Returns:
            ``(left_positions, right_positions)`` as lists, or ``None`` when
            ROS shut down before both arms reported a state.
        """
        left_arm = None
        right_arm = None
        while not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is not None and right_arm is not None:
                return left_arm, right_arm
            rate.sleep()
        return None

    def _step_towards(self, current, target, signs):
        """Move ``current`` one bounded step towards ``target`` in place.

        Each joint moves by at most ``args.arm_steps_length[i]``. Returns
        True while at least one joint has not reached its target yet.
        """
        still_moving = False
        for i, goal in enumerate(target):
            step_len = self.args.arm_steps_length[i]
            if abs(goal - current[i]) < step_len:
                current[i] = goal
            else:
                current[i] += signs[i] * step_len
                still_moving = True
        return still_moving

    def puppet_arm_publish(self, left, right):
        """Publish one joint-position command for the left and right arm."""
        self._publish_joint_positions(left, right)

    def robot_base_publish(self, vel):
        """Publish a base command: ``vel[0]`` is linear x, ``vel[1]`` angular z."""
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)

    def puppet_arm_publish_continuous(self, left, right):
        """Step both arms from their last reported state towards the targets.

        Publishes at ``args.publish_rate`` until every joint has reached its
        target, ROS shuts down, or the publish lock becomes available (the
        stop signal sent by ``puppet_arm_publish_continuous_thread``).
        """
        rate = rospy.Rate(self.args.publish_rate)
        current = self._wait_for_arm_positions(rate)
        if current is None:
            # ROS shut down before any joint state arrived; nothing to step from.
            return
        left_arm, right_arm = current
        left_symbol = [1 if goal - left_arm[i] > 0 else -1 for i, goal in enumerate(left)]
        right_symbol = [1 if goal - right_arm[i] > 0 else -1 for i, goal in enumerate(right)]
        flag = True
        step = 0
        while flag and not rospy.is_shutdown():
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_moving = self._step_towards(left_arm, left, left_symbol)
            right_moving = self._step_towards(right_arm, right, right_symbol)
            flag = left_moving or right_moving
            self._publish_joint_positions(left_arm, right_arm)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()

    def puppet_arm_publish_linear(self, left, right):
        """Publish a 100-point linear trajectory from the current state to the targets.

        The last joint is set to its target value on every trajectory point
        instead of being interpolated.
        """
        num_step = 100
        rate = rospy.Rate(200)
        current = self._wait_for_arm_positions(rate)
        if current is None:
            # ROS shut down before any joint state arrived; nothing to interpolate from.
            return
        left_arm, right_arm = current
        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)
        for traj_left, traj_right in zip(traj_left_list, traj_right_list):
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            self._publish_joint_positions(traj_left, traj_right)
            rate.sleep()

    def puppet_arm_publish_continuous_thread(self, left, right):
        """Run ``puppet_arm_publish_continuous`` in a thread, replacing any previous one."""
        if self.puppet_arm_publish_thread is not None:
            # Release the lock so the running thread sees the stop signal,
            # wait for it, then take the lock back if the thread did not.
            self.puppet_arm_publish_lock.release()
            self.puppet_arm_publish_thread.join()
            self.puppet_arm_publish_lock.acquire(False)
            self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_thread = threading.Thread(
            target=self.puppet_arm_publish_continuous, args=(left, right))
        self.puppet_arm_publish_thread.start()

    # ------------------------------------------------------------------
    # Frame synchronisation
    # ------------------------------------------------------------------
    @staticmethod
    def _stamp(msg):
        """Return the header timestamp of ``msg`` in seconds."""
        return msg.header.stamp.to_sec()

    def _pop_synced(self, queue, frame_time):
        """Drop messages older than ``frame_time`` and pop the next one."""
        while self._stamp(queue[0]) < frame_time:
            queue.popleft()
        return queue.popleft()

    def _pop_synced_image(self, queue, frame_time):
        """Like ``_pop_synced`` but convert the image message with the bridge."""
        return self.bridge.imgmsg_to_cv2(self._pop_synced(queue, frame_time), 'passthrough')

    def get_frame(self):
        """Pop one time-aligned observation from the buffers.

        The frame time is the oldest of the newest image timestamps (depth
        images included when ``args.use_depth_image`` is set).

        Returns:
            ``False`` when any required buffer is empty or has no message at
            or after the frame time; otherwise the tuple
            ``(img_front, img_left, img_right, img_front_depth,
            img_left_depth, img_right_depth, puppet_arm_left,
            puppet_arm_right, robot_base)`` with ``None`` for disabled inputs.
        """
        use_depth = self.args.use_depth_image
        use_base = self.args.use_robot_base

        image_queues = [self.img_left_deque, self.img_right_deque, self.img_front_deque]
        if use_depth:
            image_queues += [self.img_left_depth_deque, self.img_right_depth_deque,
                             self.img_front_depth_deque]
        if any(len(queue) == 0 for queue in image_queues):
            return False
        frame_time = min(self._stamp(queue[-1]) for queue in image_queues)

        required = image_queues + [self.puppet_arm_left_deque, self.puppet_arm_right_deque]
        if use_base:
            required.append(self.robot_base_deque)
        for queue in required:
            if len(queue) == 0 or self._stamp(queue[-1]) < frame_time:
                return False

        img_left = self._pop_synced_image(self.img_left_deque, frame_time)
        img_right = self._pop_synced_image(self.img_right_deque, frame_time)
        img_front = self._pop_synced_image(self.img_front_deque, frame_time)
        puppet_arm_left = self._pop_synced(self.puppet_arm_left_deque, frame_time)
        puppet_arm_right = self._pop_synced(self.puppet_arm_right_deque, frame_time)

        img_left_depth = None
        img_right_depth = None
        img_front_depth = None
        if use_depth:
            img_left_depth = self._pop_synced_image(self.img_left_depth_deque, frame_time)
            img_right_depth = self._pop_synced_image(self.img_right_depth_deque, frame_time)
            img_front_depth = self._pop_synced_image(self.img_front_depth_deque, frame_time)

        robot_base = None
        if use_base:
            robot_base = self._pop_synced(self.robot_base_deque, frame_time)

        return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
                puppet_arm_left, puppet_arm_right, robot_base)

    # ------------------------------------------------------------------
    # Subscriber callbacks
    # ------------------------------------------------------------------
    def _append_bounded(self, queue, msg):
        """Append ``msg``, dropping the oldest entry once the buffer is full."""
        if len(queue) >= self.MAX_QUEUE_LEN:
            queue.popleft()
        queue.append(msg)

    def img_left_callback(self, msg):
        self._append_bounded(self.img_left_deque, msg)

    def img_right_callback(self, msg):
        self._append_bounded(self.img_right_deque, msg)

    def img_front_callback(self, msg):
        self._append_bounded(self.img_front_deque, msg)

    def img_left_depth_callback(self, msg):
        self._append_bounded(self.img_left_depth_deque, msg)

    def img_right_depth_callback(self, msg):
        self._append_bounded(self.img_right_depth_deque, msg)

    def img_front_depth_callback(self, msg):
        self._append_bounded(self.img_front_depth_deque, msg)

    def puppet_arm_left_callback(self, msg):
        self._append_bounded(self.puppet_arm_left_deque, msg)

    def puppet_arm_right_callback(self, msg):
        self._append_bounded(self.puppet_arm_right_deque, msg)

    def robot_base_callback(self, msg):
        self._append_bounded(self.robot_base_deque, msg)

    def ctrl_callback(self, msg):
        """Store the latest control state from ``msg.data``."""
        with self.ctrl_state_lock:
            self.ctrl_state = msg.data

    def get_ctrl_state(self):
        """Return the most recently stored control state."""
        with self.ctrl_state_lock:
            return self.ctrl_state

    # ------------------------------------------------------------------
    # ROS wiring
    # ------------------------------------------------------------------
    def init_ros(self):
        """Initialise the ROS node, the subscribers and the command publishers."""
        rospy.init_node('joint_state_publisher', anonymous=True)
        args = self.args

        def subscribe(topic, msg_type, callback):
            rospy.Subscriber(topic, msg_type, callback, queue_size=1000, tcp_nodelay=True)

        subscribe(args.img_left_topic, Image, self.img_left_callback)
        subscribe(args.img_right_topic, Image, self.img_right_callback)
        subscribe(args.img_front_topic, Image, self.img_front_callback)
        if args.use_depth_image:
            subscribe(args.img_left_depth_topic, Image, self.img_left_depth_callback)
            subscribe(args.img_right_depth_topic, Image, self.img_right_depth_callback)
            subscribe(args.img_front_depth_topic, Image, self.img_front_depth_callback)
        subscribe(args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback)
        subscribe(args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback)
        subscribe(args.robot_base_topic, Odometry, self.robot_base_callback)
        self.puppet_arm_left_publisher = rospy.Publisher(
            args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
        self.puppet_arm_right_publisher = rospy.Publisher(
            args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
        self.robot_base_publisher = rospy.Publisher(
            args.robot_base_cmd_topic, Twist, queue_size=10)
def _str2bool(value):
    """Parse a command-line boolean.

    ``type=bool`` treats every non-empty string (including ``"False"``) as
    True, so boolean options use this parser instead. The empty string stays
    False, matching what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_arguments():
    """Parse the command-line options of the inference script.

    Returns:
        argparse.Namespace with checkpoint, model, ROS topic and runtime
        options. Only ``--ckpt_dir`` is required.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Checkpoint / task
    add('--ckpt_dir', type=str, help='ckpt_dir', required=True)
    add('--task_name', type=str, help='task_name', default='aloha_mobile_dummy')
    add('--max_publish_step', type=int, help='max_publish_step', default=10000)
    add('--ckpt_name', type=str, help='ckpt_name', default='policy_best.ckpt')
    add('--ckpt_stats_name', type=str, help='ckpt_stats_name', default='dataset_stats.pkl')
    add('--policy_class', type=str, help='policy_class, capitalize', default='ACT')

    # Model / training hyper-parameters
    add('--batch_size', type=int, help='batch_size', default=8)
    add('--seed', type=int, help='seed', default=0)
    add('--num_epochs', type=int, help='num_epochs', default=2000)
    add('--lr', type=float, help='lr', default=1e-5)
    add('--weight_decay', type=float, help='weight_decay', default=1e-4)
    add('--dilation', action='store_true',
        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    add('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
        help="Type of positional embedding to use on top of the image features")
    add('--masks', action='store_true',
        help="Train segmentation head if the flag is provided")
    add('--kl_weight', type=int, help='KL Weight', default=10)
    add('--hidden_dim', type=int, help='hidden_dim', default=512)
    add('--dim_feedforward', type=int, help='dim_feedforward', default=3200)
    add('--temporal_agg', type=_str2bool, help='temporal_agg', default=True)
    add('--state_dim', type=int, help='state_dim', default=14)
    add('--lr_backbone', type=float, help='lr_backbone', default=1e-5)
    add('--backbone', type=str, help='backbone', default='resnet18')
    add('--loss_function', type=str, help='loss_function l1 l2 l1+l2', default='l1')
    add('--enc_layers', type=int, help='enc_layers', default=4)
    add('--dec_layers', type=int, help='dec_layers', default=7)
    add('--nheads', type=int, help='nheads', default=8)
    add('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
    add('--pre_norm', action='store_true')

    # ROS topics
    add('--img_front_topic', type=str, help='img_front_topic',
        default='/camera_f/color/image_raw')
    add('--img_left_topic', type=str, help='img_left_topic',
        default='/camera_l/color/image_raw')
    add('--img_right_topic', type=str, help='img_right_topic',
        default='/camera_r/color/image_raw')
    add('--img_front_depth_topic', type=str, help='img_front_depth_topic',
        default='/camera_f/depth/image_raw')
    add('--img_left_depth_topic', type=str, help='img_left_depth_topic',
        default='/camera_l/depth/image_raw')
    add('--img_right_depth_topic', type=str, help='img_right_depth_topic',
        default='/camera_r/depth/image_raw')
    add('--puppet_arm_left_cmd_topic', type=str, help='puppet_arm_left_cmd_topic',
        default='/master/joint_left')
    add('--puppet_arm_right_cmd_topic', type=str, help='puppet_arm_right_cmd_topic',
        default='/master/joint_right')
    add('--puppet_arm_left_topic', type=str, help='puppet_arm_left_topic',
        default='/puppet/joint_left')
    add('--puppet_arm_right_topic', type=str, help='puppet_arm_right_topic',
        default='/puppet/joint_right')
    add('--robot_base_topic', type=str, help='robot_base_topic', default='/odom_raw')
    add('--robot_base_cmd_topic', type=str, help='robot_base_cmd_topic', default='/cmd_vel')

    # Runtime behaviour
    add('--use_robot_base', type=_str2bool, help='use_robot_base', default=False)
    add('--publish_rate', type=int, help='publish_rate', default=40)
    add('--pos_lookahead_step', type=int, help='pos_lookahead_step', default=0)
    add('--chunk_size', type=int, help='chunk_size', default=32)
    # One value per joint: the value is indexed per joint by the publisher,
    # so the command line must produce a list just like the default.
    add('--arm_steps_length', type=float, nargs='+',
        help='arm_steps_length (one value per joint)',
        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2])
    add('--use_actions_interpolation', type=_str2bool, help='use_actions_interpolation',
        default=False)
    add('--use_depth_image', type=_str2bool, help='use_depth_image', default=False)

    # for Diffusion
    add('--observation_horizon', type=int, help='observation_horizon', default=1)
    add('--action_horizon', type=int, help='action_horizon', default=8)
    add('--num_inference_timesteps', type=int, help='num_inference_timesteps', default=10)
    # The default is fractional, so the option must parse as float.
    add('--ema_power', type=float, help='ema_power', default=0.75)

    return parser.parse_args()
def main():
    """Parse the CLI options, bring up ROS and run policy inference."""
    cli_args = get_arguments()
    # The ROS operator is created before the model config, as in the original order.
    operator = RosOperator(cli_args)
    model_config = get_model_config(cli_args)
    model_inference(cli_args, model_config, operator, save_episode=True)
# Script entry point: run inference only when this file is executed directly.
if __name__ == '__main__':
main()
# Example invocation:
# python act/inference.py --ckpt_dir ~/train0314/

207
collect_data/replay_data.py Normal file → Executable file
View File

@@ -10,24 +10,63 @@ from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from geometry_msgs.msg import Twist
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import sys
sys.path.append("./")
def load_hdf5(dataset_dir, dataset_name):
    """Load one recorded episode from ``<dataset_dir>/<dataset_name>.hdf5``.

    Args:
        dataset_dir: Directory that contains the episode file.
        dataset_name: File name without the ``.hdf5`` extension.

    Returns:
        ``(qpos, qvel, effort, action, base_action, image_dict)``. ``effort``
        is ``None`` when the file has no ``/observations/effort`` dataset.
        ``image_dict`` maps each camera name to its frames; when the file's
        ``compress`` attribute is set, every frame is decoded with
        ``cv2.imdecode``.

    Raises:
        SystemExit: with status 1 when the episode file does not exist.
    """
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        # A missing dataset is an error, so exit with a non-zero status.
        raise SystemExit(1)

    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        # Test for the dataset that is actually read; checking the root
        # group for 'effort' never matched '/observations/effort'.
        if 'observations/effort' in root:
            effort = root['/observations/effort'][()]
        else:
            effort = None
        action = root['/action'][()]
        base_action = root['/base_action'][()]
        image_group = root['/observations/images/']
        image_dict = {cam_name: image_group[cam_name][()] for cam_name in image_group.keys()}

    if compressed:
        # Each padded buffer is passed to cv2.imdecode unchanged, exactly as
        # before; the per-frame length table was read but never applied.
        image_dict = {
            cam_name: [cv2.imdecode(padded_frame, 1) for padded_frame in padded_frames]
            for cam_name, padded_frames in image_dict.items()
        }

    return qpos, qvel, effort, action, base_action, image_dict
# Replay entry point: initialises the "replay_node" ROS node, creates publishers and
# publishes recorded joint commands (and, in one branch, joint states, images and base velocities).
# NOTE(review): this span appears to interleave lines from two revisions of main()
# (duplicate publisher definitions, several consecutive `rate = ...` assignments, a diff hunk
# marker, and names such as qposs / image_dicts / base_actions that are only assigned in
# commented-out code). Confirm against the real file before relying on the comments below.
def main(args):
rospy.init_node("replay_node")
bridge = CvBridge()
# img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
# img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
# img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
# puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
# puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)
# robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
# dataset_dir = args.dataset_dir
@@ -35,78 +74,130 @@ def main(args):
# task_name = args.task_name
# dataset_name = f'episode_{episode_idx}'
# Load a single episode and select the action / velocity / effort columns.
dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
actions = dataset.hf_dataset.select_columns("action")
velocitys = dataset.hf_dataset.select_columns("observation.velocity")
efforts = dataset.hf_dataset.select_columns("observation.effort")
origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279]
origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
# One JointState message is reused for every publish; only stamp/position/velocity/effort change.
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', ''] # set the joint names
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # set the joint names
twist_msg = Twist()
rate = rospy.Rate(args.fps)
rate = rospy.Rate(args.frame_rate)
# qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)
# NOTE(review): this literal overrides the `actions` loaded above if both lines are live — confirm.
actions = [[ 0.1277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.3238, -0.1094,
0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]]
# Starting values for the interpolation towards the first recorded sample.
last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
rate = rospy.Rate(50)
for idx in range(len(actions)):
action = actions[idx]['action'].detach().cpu().numpy()
velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy()
effort = efforts[idx]['observation.effort'].detach().cpu().numpy()
if(rospy.is_shutdown()):
break
new_actions = np.linspace(last_action, action, 5) # interpolate
new_velocitys = np.linspace(last_velocity, velocity, 5) # interpolate
new_efforts = np.linspace(last_effort, effort, 5) # interpolate
last_action = action
last_velocity = velocity
last_effort = effort
for act in new_actions:
print(np.round(act[:7], 4))
cur_timestamp = rospy.Time.now() # set the timestamp
joint_state_msg.header.stamp = cur_timestamp
if not args.only_pub_master:
last_action = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279, 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
rate = rospy.Rate(100)
for action in actions:
if(rospy.is_shutdown()):
break
# First 7 values go to the left master arm, the remaining ones to the right.
joint_state_msg.position = act[:7]
joint_state_msg.velocity = last_velocity[:7]
joint_state_msg.effort = last_effort[:7]
new_actions = np.linspace(last_action, action, 50) # interpolate
last_action = action
for act in new_actions:
print(np.round(act[:7], 4))
cur_timestamp = rospy.Time.now() # set the timestamp
joint_state_msg.header.stamp = cur_timestamp
joint_state_msg.position = act[:7]
master_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = act[7:]
master_arm_right_publisher.publish(joint_state_msg)
if(rospy.is_shutdown()):
break
rate.sleep()
else:
i = 0
while(not rospy.is_shutdown() and i < len(actions)):
print("left: ", np.round(qposs[i][:7], 4), " right: ", np.round(qposs[i][7:], 4))
cam_names = [k for k in image_dicts.keys()]
image0 = image_dicts[cam_names[0]][i]
image0 = image0[:, :, [2, 1, 0]] # swap B and R channel
image1 = image_dicts[cam_names[1]][i]
image1 = image1[:, :, [2, 1, 0]] # swap B and R channel
image2 = image_dicts[cam_names[2]][i]
image2 = image2[:, :, [2, 1, 0]] # swap B and R channel
cur_timestamp = rospy.Time.now() # set the timestamp
joint_state_msg.header.stamp = cur_timestamp
joint_state_msg.position = actions[i][:7]
master_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = act[7:]
# NOTE(review): velocity uses [:7] here while effort uses [7:] — confirm which slice the right arm needs.
joint_state_msg.velocity = last_velocity[:7]
joint_state_msg.effort = last_effort[7:]
joint_state_msg.position = actions[i][7:]
master_arm_right_publisher.publish(joint_state_msg)
if(rospy.is_shutdown()):
break
rate.sleep()
joint_state_msg.position = qposs[i][:7]
puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = qposs[i][7:]
puppet_arm_right_publisher.publish(joint_state_msg)
img_front_publisher.publish(bridge.cv2_to_imgmsg(image0, "bgr8"))
img_left_publisher.publish(bridge.cv2_to_imgmsg(image1, "bgr8"))
img_right_publisher.publish(bridge.cv2_to_imgmsg(image2, "bgr8"))
twist_msg.linear.x = base_actions[i][0]
twist_msg.angular.z = base_actions[i][1]
robot_base_publisher.publish(twist_msg)
i += 1
rate.sleep()
# Script entry point for the replay tool: builds the argument namespace and calls main(args).
# NOTE(review): parse_args() is called twice and add_argument() calls follow the first call,
# so this span appears to interleave two revisions (hard-coded `args.*` assignments vs. CLI
# options). Confirm against the real file.
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
# default='/master/joint_left', required=False)
# parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
# default='/master/joint_right', required=False)
args = parser.parse_args()
# Hard-coded replay settings assigned directly onto the parsed namespace.
args.repo_id = "tangger/test"
args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
args.episode = 1 # replay episode
args.master_arm_left_topic = "/master/joint_left"
args.master_arm_right_topic = "/master/joint_right"
args.fps = 30
# Command-line options: dataset location, camera names, ROS topics and frame rate.
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=False)
parser.add_argument('--task_name', action='store', type=str, help='Task name.',
default="aloha_mobile_dummy", required=False)
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.',default=0, required=False)
parser.add_argument('--camera_names', action='store', type=str, help='camera_names',
default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False)
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
default='/camera_f/color/image_raw', required=False)
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
default='/camera_l/color/image_raw', required=False)
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
default='/camera_r/color/image_raw', required=False)
parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
default='/master/joint_left', required=False)
parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
default='/master/joint_right', required=False)
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
default='/puppet/joint_left', required=False)
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
default='/puppet/joint_right', required=False)
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
default='/cmd_vel', required=False)
# NOTE(review): type=bool turns any non-empty string (including "False") into True.
parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
default=False, required=False)
parser.add_argument('--frame_rate', action='store', type=int, help='frame_rate',
default=30, required=False)
parser.add_argument('--only_pub_master', action='store_true', help='only_pub_master',required=False)
args = parser.parse_args()
main(args)
# Example invocation (names a different script, collect_data.py):
# python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0

10
collect_data/requirements.txt Executable file
View File

@@ -0,0 +1,10 @@
opencv-python==4.9.0.80
matplotlib==3.7.5
h5py==3.8.0
dm-env==1.6
numpy==1.23.4
pyyaml
rospkg==1.5.0
catkin-pkg==1.0.0
empy==3.3.4
PyQt5==5.15.10

View File

@@ -1,372 +0,0 @@
import yaml
from collections import deque
import rospy
from cv_bridge import CvBridge
from typing import Dict, Any, Optional, List
from sensor_msgs.msg import Image, JointState
from nav_msgs.msg import Odometry
import argparse
class Robot:
def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
"""
机器人基类,处理通用初始化逻辑
Args:
config_file: YAML配置文件路径
args: 运行时参数
"""
self._load_config(config_file)
self._merge_runtime_args(args)
self._init_components()
self._init_data_structures()
self.init_ros()
self.init_features()
self.warmup()
def _load_config(self, config_file: str) -> None:
    """Load the YAML configuration file into ``self.config``.

    Args:
        config_file: Path to the YAML configuration file.

    An empty file yields an empty dict, so later ``in`` / ``.get`` lookups
    on ``self.config`` work instead of failing on ``None``.
    """
    # Read as UTF-8 explicitly so the result does not depend on the locale.
    with open(config_file, 'r', encoding='utf-8') as f:
        loaded = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    self.config = {} if loaded is None else loaded
def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
"""合并运行时参数到配置"""
if args is None:
return
runtime_params = {
'frame_rate': getattr(args, 'fps', None),
'max_timesteps': getattr(args, 'max_timesteps', None),
'episode_idx': getattr(args, 'episode_idx', None),
'use_depth_image': getattr(args, 'use_depth_image', None),
'use_robot_base': getattr(args, 'use_base', None),
'video': getattr(args, 'video', False),
'control_type': getattr(args, 'control_type', False),
}
for key, value in runtime_params.items():
if value is not None:
self.config[key] = value
def _init_components(self) -> None:
    """Initialise the core components and validate the configuration."""
    self.bridge = CvBridge()
    # Registries start empty; nothing in this method fills them.
    self.subscribers, self.publishers = {}, {}
    self._validate_config()
def _validate_config(self) -> None:
"""验证配置完整性"""
required_sections = ['cameras', 'arm']
for section in required_sections:
if section not in self.config:
raise ValueError(f"Missing required config section: {section}")
def _init_data_structures(self) -> None:
"""初始化数据结构模板方法"""
# 相机数据
self.cameras = self.config.get('cameras', {})
self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras}
# 深度数据
self.use_depth_image = self.config.get('use_depth_image', False)
if self.use_depth_image:
self.sync_depth_queues = {
name: deque(maxlen=2000)
for name, cam in self.cameras.items()
if 'depth_topic_name' in cam
}
# 机械臂数据
self.arms = self.config.get('arm', {})
if self.config.get('control_type', '') != 'record':
# 如果不是录制模式,则仅初始化从机械臂数据队列
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name}
else:
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms}
# 机器人基座数据
self.use_robot_base = self.config.get('use_robot_base', False)
if self.use_robot_base:
self.sync_base_queue = deque(maxlen=2000)
def init_ros(self) -> None:
    """Start the ROS node, then register every subscriber and publisher."""
    node_name = self.config.get('ros_node_name', 'generic_robot_node')
    rospy.init_node(f"{node_name}", anonymous=True)

    # Wiring order: cameras, arms, base, then a status summary.
    self._setup_camera_subscribers()
    self._setup_arm_subscribers_publishers()
    self._setup_base_subscriber()
    self._log_ros_status()
def init_features(self):
    """Build ``self.features`` from the configuration and print it.

    Camera, state and action features are always generated; base features
    are added only when the robot base is enabled.
    """
    self.features = {}
    for build in (
        self._init_camera_features,
        self._init_state_features,
        self._init_action_features,
    ):
        build()
    if self.use_robot_base:
        self._init_base_features()

    from pprint import pprint as dump
    dump(self.features, indent=4)
def _init_camera_features(self):
"""处理所有相机特征"""
for cam_name, cam_config in self.cameras.items():
# 普通图像
self.features[f"observation.images.{cam_name}"] = {
"dtype": "video" if self.config.get("video", False) else "image",
"shape": cam_config.get("rgb_shape", [480, 640, 3]),
"names": ["height", "width", "channel"],
# "video_info": {
# "video.fps": cam_config.get("fps", 30.0),
# "video.codec": cam_config.get("codec", "av1"),
# "video.pix_fmt": cam_config.get("pix_fmt", "yuv420p"),
# "video.is_depth_map": False,
# "has_audio": False
# }
}
if self.config.get("use_depth_image", False):
self.features[f"observation.images.depth_{cam_name}"] = {
"dtype": "uint16",
"shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1),
"names": ["height", "width"],
}
def _init_state_features(self):
state = self.config.get('state', {})
# 状态特征
self.features["observation.state"] = {
"dtype": "float32",
"shape": (len(state.get('motors', "")),),
"names": {"motors": state.get('motors', "")}
}
if self.config.get('velocity'):
velocity = self.config.get('velocity', "")
self.features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(velocity.get('motors', "")),),
"names": {"motors": velocity.get('motors', "")}
}
if self.config.get('effort'):
effort = self.config.get('effort', "")
self.features["observation.effort"] = {
"dtype": "float32",
"shape": (len(effort.get('motors', "")),),
"names": {"motors": effort.get('motors', "")}
}
def _init_action_features(self):
action = self.config.get('action', {})
# 状态特征
self.features["action"] = {
"dtype": "float32",
"shape": (len(action.get('motors', "")),),
"names": {"motors": action.get('motors', "")}
}
def _init_base_features(self):
"""处理基座特征"""
self.features["observation.base_vel"] = {
"dtype": "float32",
"shape": (2,),
"names": ["linear_x", "angular_z"]
}
def _setup_camera_subscribers(self) -> None:
    """Subscribe to the colour (and optionally depth) topic of each camera.

    Subscribers are stored in ``self.subscribers`` under ``camera_<name>``
    and ``depth_<name>``.
    """
    def subscribe(topic, cam, depth):
        return rospy.Subscriber(
            topic,
            Image,
            self._make_camera_callback(cam, is_depth=depth),
            queue_size=1000,
            tcp_nodelay=True
        )

    for cam_name, cam_config in self.cameras.items():
        if 'img_topic_name' in cam_config:
            self.subscribers[f"camera_{cam_name}"] = subscribe(
                cam_config['img_topic_name'], cam_name, False
            )
        if self.use_depth_image and 'depth_topic_name' in cam_config:
            self.subscribers[f"depth_{cam_name}"] = subscribe(
                cam_config['depth_topic_name'], cam_name, True
            )
def _setup_arm_subscribers_publishers(self) -> None:
    """Create the arm subscribers and publishers for the current mode.

    In ``record`` mode every arm is subscribed to. In any other mode only
    arms whose name contains ``puppet`` are subscribed to, and a publisher
    is created for arms whose name contains ``master``.

    Arms that are needed but have no ``topic_name`` are skipped with a
    warning in both modes. Previously only record mode tolerated a missing
    topic; the other mode raised ``KeyError``.
    """
    record_mode = self.config.get('control_type', '') == 'record'

    for arm_name, arm_config in self.arms.items():
        wants_subscriber = record_mode or 'puppet' in arm_name
        wants_publisher = not record_mode and 'master' in arm_name
        if not (wants_subscriber or wants_publisher):
            continue

        if 'topic_name' not in arm_config:
            rospy.logwarn(f"Arm {arm_name} has no 'topic_name' configured; skipping")
            continue
        topic = arm_config['topic_name']

        if wants_subscriber:
            self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
                topic,
                JointState,
                self._make_arm_callback(arm_name),
                queue_size=1000,
                tcp_nodelay=True
            )
        if wants_publisher:
            self.publishers[f"arm_{arm_name}"] = rospy.Publisher(
                topic,
                JointState,
                queue_size=10
            )
def _setup_base_subscriber(self) -> None:
    """Subscribe to the base odometry topic when the base is enabled."""
    # Nothing to do unless the base is enabled and configured.
    if not self.use_robot_base or 'robot_base' not in self.config:
        return

    base_topic = self.config['robot_base']['topic_name']
    self.subscribers['base'] = rospy.Subscriber(
        base_topic,
        Odometry,
        self.robot_base_callback,
        queue_size=1000,
        tcp_nodelay=True
    )
def _log_ros_status(self) -> None:
    """Log the node name and the state of every registered subscriber."""
    rospy.loginfo("\n=== ROS订阅状态 ===")
    rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
    rospy.loginfo("活跃的订阅者:")
    for topic, sub in self.subscribers.items():
        # Reported as active when the subscriber still has an implementation.
        state = '活跃' if sub.impl else '未连接'
        rospy.loginfo(f" - {topic}: {state}")
    rospy.loginfo("=================")
def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
"""生成相机回调函数工厂方法"""
def callback(msg):
try:
target_queue = (
self.sync_depth_queues[cam_name]
if is_depth
else self.sync_img_queues[cam_name]
)
if len(target_queue) >= 2000:
target_queue.popleft()
target_queue.append(msg)
except Exception as e:
rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
return callback
def _make_arm_callback(self, arm_name: str):
"""生成机械臂回调函数工厂方法"""
def callback(msg):
try:
if len(self.sync_arm_queues[arm_name]) >= 2000:
self.sync_arm_queues[arm_name].popleft()
self.sync_arm_queues[arm_name].append(msg)
except Exception as e:
rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
return callback
def robot_base_callback(self, msg):
    """Default base callback: buffer the message in ``sync_base_queue``."""
    base_queue = self.sync_base_queue
    # Drop the oldest message once the buffer is full.
    if not len(base_queue) < 2000:
        base_queue.popleft()
    base_queue.append(msg)
def warmup(self, timeout: float = 10.0, min_camera_msgs: int = 50,
           min_base_msgs: int = 20) -> bool:
    """Wait until the sensor queues have buffered enough messages.

    Camera queues (and depth queues when enabled) need ``min_camera_msgs``
    messages; the base queue, when enabled, needs ``min_base_msgs``. Arm
    queues are not checked. The queues are polled at 10 Hz.

    Args:
        timeout: Maximum time to wait in seconds before giving up.
        min_camera_msgs: Messages required in every colour/depth queue.
        min_base_msgs: Messages required in the base queue.

    Returns:
        bool: True if warmup succeeded, False on timeout or ROS shutdown.
    """
    start_time = rospy.Time.now().to_sec()
    rate = rospy.Rate(10)  # Check at 10Hz
    rospy.loginfo("Starting warmup process...")

    while not rospy.is_shutdown():
        # Give up once the timeout has elapsed.
        if rospy.Time.now().to_sec() - start_time > timeout:
            rospy.logwarn("Warmup timed out before all queues were filled")
            return False

        all_ready = True

        # Colour queues: report only the first camera that is still short.
        for cam_name in self.cameras:
            count = len(self.sync_img_queues[cam_name])
            if count < min_camera_msgs:
                rospy.loginfo(f"Waiting for camera {cam_name} (current: {count}/{min_camera_msgs})")
                all_ready = False
                break

        # Depth queues, if enabled.
        if self.use_depth_image:
            for cam_name, queue in self.sync_depth_queues.items():
                count = len(queue)
                if count < min_camera_msgs:
                    rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {count}/{min_camera_msgs})")
                    all_ready = False
                    break

        # Base queue, if enabled.
        if self.use_robot_base:
            count = len(self.sync_base_queue)
            if count < min_base_msgs:
                rospy.loginfo(f"Waiting for base (current: {count}/{min_base_msgs})")
                all_ready = False

        if all_ready:
            rospy.loginfo("Warmup completed successfully")
            return True

        rate.sleep()

    # ROS is shutting down.
    return False
def get_frame(self) -> Optional[Dict[str, Any]]:
    """Return one synchronised frame of data; subclasses must override.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    reason = "Subclasses must implement get_frame()"
    raise NotImplementedError(reason)
def process(self) -> tuple:
    """Run the main processing step; subclasses must override.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    reason = "Subclasses must implement process()"
    raise NotImplementedError(reason)

View File

@@ -1,26 +0,0 @@
import yaml
import argparse
from typing import Dict, List, Any, Optional
from rosrobot import Robot
from agilex_robot import AgilexRobot
class RobotFactory:
    """Factory that picks the robot implementation named in a YAML config."""

    @staticmethod
    def create(config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
        """Create the robot instance selected by the configuration file.

        The ``robot_type`` key selects the implementation and defaults to
        ``'agilex'`` when absent.

        Args:
            config_file: Path to the YAML configuration file.
            args: Runtime arguments forwarded to the robot constructor.

        Raises:
            ValueError: If ``robot_type`` names an unsupported robot.
        """
        with open(config_file, 'r', encoding='utf-8') as f:
            # An empty YAML document loads as None; treat it as an empty
            # mapping so the default robot type still applies.
            config = yaml.safe_load(f) or {}

        robot_type = config.get('robot_type', 'agilex')
        if robot_type == 'agilex':
            return AgilexRobot(config_file, args)
        # Other robot types can be added here.
        raise ValueError(f"Unsupported robot type: {robot_type}")

322
collect_data/utils.py Normal file
View File

@@ -0,0 +1,322 @@
import cv2
import numpy as np
import h5py
import time
def display_camera_grid(image_dict, grid_shape=None, window_name="MindRobot-V1 Data Collection", scale=1.0):
    """Show several camera images side by side in one OpenCV window.

    Each image keeps its own aspect ratio; the whole mosaic is scaled by
    ``scale``. Entries that are not numpy arrays are ignored.

    Args:
        image_dict: Mapping of camera name to image array.
        grid_shape: ``(rows, cols)`` layout, or None to choose automatically
            (one row for up to three images, otherwise two rows).
        window_name: Title of the OpenCV window.
        scale: Display scale factor (0.5 shows the mosaic at 50% size).

    Returns:
        The composed canvas, or None when there is nothing to display.

    Raises:
        TypeError: If ``image_dict`` is not a dict.
    """
    if not isinstance(image_dict, dict):
        raise TypeError("输入必须是字典类型")

    def to_bgr_uint8(frame):
        # Normalise dtype and channel layout for display.
        # NOTE(review): 3-channel input is converted as RGB but 4-channel
        # input as BGRA - confirm that asymmetry is intended.
        if frame.dtype != np.uint8:
            frame = frame.astype(np.uint8)
        if frame.ndim == 2:
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
        channels = frame.shape[2]
        if channels == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2BGR)
        if channels == 3:
            return cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        return frame

    frames = [
        (label, to_bgr_uint8(picture))
        for label, picture in image_dict.items()
        if isinstance(picture, np.ndarray)
    ]
    if not frames:
        print("错误: 没有有效的图像可显示!")
        return None

    # Pick a layout automatically when none was given.
    count = len(frames)
    if grid_shape is None:
        grid_shape = (1, count) if count <= 3 else (2, int(np.ceil(count/2)))
    rows, cols = grid_shape
    shown = frames[:rows*cols]

    # Each row is as tall, and each column as wide, as its largest image.
    row_heights = [0]*rows
    col_widths = [0]*cols
    for idx, (_, picture) in enumerate(shown):
        r, c = divmod(idx, cols)
        row_heights[r] = max(row_heights[r], picture.shape[0])
        col_widths[c] = max(col_widths[c], picture.shape[1])

    canvas_h = int(sum(row_heights) * scale)
    canvas_w = int(sum(col_widths) * scale)
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)

    # Scaled pixel offset at which each row / column starts.
    row_pos = [0]
    for end in range(1, rows + 1):
        row_pos.append(int(sum(row_heights[:end])*scale))
    col_pos = [0]
    for end in range(1, cols + 1):
        col_pos.append(int(sum(col_widths[:end])*scale))

    font_scale = 0.8 *scale
    thickness = max(2, int(2 * scale))
    for idx, (label, picture) in enumerate(shown):
        r, c = divmod(idx, cols)
        left, top = col_pos[c], row_pos[r]

        # Resize by the global scale and paste at the cell's top-left corner.
        display_h = int(picture.shape[0] * scale)
        display_w = int(picture.shape[1] * scale)
        canvas[top:top+display_h, left:left+display_w] = cv2.resize(picture, (display_w, display_h))

        # Camera label, drawn with a font size that follows the scale.
        cv2.putText(canvas, label, (left+10, top+30),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255,255,255), thickness)

    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.imshow(window_name, canvas)
    cv2.resizeWindow(window_name, canvas_w, canvas_h)
    cv2.waitKey(1)
    return canvas
# 保存数据函数
def save_data(args, timesteps, actions, dataset_path):
    """Write one recorded episode to ``<dataset_path>.hdf5``.

    ``actions`` is consumed completely and one entry is popped from
    ``timesteps`` per action, so both lists are modified in place.

    Args:
        args: Object providing ``camera_names`` and ``use_depth_image``.
        timesteps: Timestep objects whose ``observation`` mapping holds
            qpos, qvel, effort, base_vel, images and optionally images_depth.
        actions: Actions, one per stored step.
        dataset_path: Output path without the ``.hdf5`` extension.
    """
    data_size = len(actions)
    image_key = '/observations/images/{}'.format
    depth_key = '/observations/images_depth/{}'.format

    # Per-step buffers, keyed by their HDF5 dataset path.
    data_dict = {
        '/observations/qpos': [],
        '/observations/qvel': [],
        '/observations/effort': [],
        '/action': [],
        '/base_action': [],
    }
    for cam_name in args.camera_names:
        data_dict[image_key(cam_name)] = []
        if args.use_depth_image:
            data_dict[depth_key(cam_name)] = []

    # Pair each action with the observation recorded before it.
    while actions:
        action = actions.pop(0)
        ts = timesteps.pop(0)

        data_dict['/observations/qpos'].append(ts.observation['qpos'])
        data_dict['/observations/qvel'].append(ts.observation['qvel'])
        data_dict['/observations/effort'].append(ts.observation['effort'])
        data_dict['/action'].append(action)
        data_dict['/base_action'].append(ts.observation['base_vel'])

        for cam_name in args.camera_names:
            data_dict[image_key(cam_name)].append(ts.observation['images'][cam_name])
            if args.use_depth_image:
                data_dict[depth_key(cam_name)].append(ts.observation['images_depth'][cam_name])

    t0 = time.time()
    with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2*2) as root:
        # File attributes: not simulated, images stored uncompressed.
        root.attrs['sim'] = False
        root.attrs['compress'] = False

        obs = root.create_group('observations')
        image = obs.create_group('images')
        for cam_name in args.camera_names:
            image.create_dataset(cam_name, (data_size, 480, 640, 3), dtype='uint8',
                                 chunks=(1, 480, 640, 3), )
        if args.use_depth_image:
            image_depth = obs.create_group('images_depth')
            for cam_name in args.camera_names:
                image_depth.create_dataset(cam_name, (data_size, 480, 640), dtype='uint16',
                                           chunks=(1, 480, 640), )

        for signal in ('qpos', 'qvel', 'effort'):
            obs.create_dataset(signal, (data_size, 14))
        root.create_dataset('action', (data_size, 14))
        root.create_dataset('base_action', (data_size, 2))

        # Copy every buffered list into its dataset.
        for name, array in data_dict.items():
            root[name][...] = array

    print(f'\033[32m\nSaving: {time.time() - t0:.1f} secs. %s \033[0m\n'%dataset_path)
def is_headless():
    """Check if the environment is headless (no display available).

    The check tries to create and immediately destroy a hidden Tk window.

    Returns:
        bool: True if the environment is headless, False otherwise.
    """
    try:
        import tkinter as tk
        root = tk.Tk()
        try:
            root.withdraw()
            root.update()
        finally:
            # Always release the window, even if update() fails.
            root.destroy()
        return False
    except Exception:
        # A missing tkinter or display counts as headless. Unlike a bare
        # `except:`, KeyboardInterrupt and SystemExit still propagate.
        return True
def init_keyboard_listener():
    """Start a keyboard listener that records control requests.

    Key mappings:
    - Left arrow: Start data recording
    - Right arrow: Save current data
    - Down arrow: Discard current data
    - Up arrow: Replay current data
    - ESC: Early termination

    Returns:
        tuple: (listener, events) - the started listener (None when the
        environment is headless) and the shared events dictionary.
    """
    events = dict.fromkeys(
        ("exit_early", "record_start", "save_data", "discard_data", "replay_data"),
        False,
    )

    if is_headless():
        print(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # pynput is imported only when a display is available.
    from pynput import keyboard

    # (key, message to print, flags to set) for every handled key.
    bindings = (
        (
            keyboard.Key.left,
            "← Left arrow: STARTING data recording...",
            {"record_start": True, "exit_early": False, "save_data": False, "discard_data": False},
        ),
        (
            keyboard.Key.right,
            "→ Right arrow: SAVING current data...",
            {"save_data": True, "exit_early": False, "record_start": False},
        ),
        (
            keyboard.Key.down,
            "↓ Down arrow: DISCARDING current data...",
            {"discard_data": True, "exit_early": False, "record_start": False},
        ),
        (
            keyboard.Key.up,
            "↑ Up arrow: REPLAYING current data...",
            {"replay_data": True, "exit_early": False},
        ),
        (
            keyboard.Key.esc,
            "ESC: EARLY TERMINATION requested",
            {"exit_early": True, "record_start": False},
        ),
    )

    def on_press(key):
        try:
            for bound_key, message, changes in bindings:
                if key == bound_key:
                    print(message)
                    events.update(changes)
                    break
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events
import yaml
from argparse import Namespace
def load_config(yaml_path):
    """Read a YAML file and expose its top-level keys as attributes.

    Args:
        yaml_path: Path to the YAML file.

    Returns:
        An ``argparse.Namespace`` built from the loaded mapping.
    """
    with open(yaml_path, 'r') as stream:
        settings = yaml.safe_load(stream)
    # Same attribute-style access as parsed command-line arguments.
    return Namespace(**settings)
import platform
import subprocess
# import pyttsx3
def say(text, blocking=False):
    """Speak ``text`` through the platform's text-to-speech command.

    Uses ``say`` on macOS, ``edge-playback`` on Linux and PowerShell's
    ``System.Speech`` synthesizer on Windows.

    Args:
        text: Text to speak.
        blocking: Wait for the command to finish (and raise on a non-zero
            exit status) instead of starting it in the background.

    Raises:
        RuntimeError: On an unsupported operating system.
    """
    system = platform.system()
    if system == "Darwin":
        cmd = ["say", text]
    elif system == "Linux":
        cmd = ["edge-playback", "--text", text]
    elif system == "Windows":
        # The text sits inside a single-quoted PowerShell string, where a
        # literal quote is written as two quotes. Without this an apostrophe
        # ends the string early and the rest runs as PowerShell code.
        escaped = text.replace("'", "''")
        cmd = [
            "PowerShell",
            "-Command",
            "Add-Type -AssemblyName System.Speech; "
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{escaped}')",
        ]
    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True)
    else:
        # CREATE_NO_WINDOW exists only on Windows, so it is referenced lazily.
        subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
def log_say(text, play_sounds, blocking=False):
    """Print ``text`` and, when ``play_sounds`` is truthy, also speak it."""
    print(text)
    if not play_sounds:
        return
    say(text, blocking)

View File

@@ -0,0 +1,159 @@
#coding=utf-8
import os
import numpy as np
import cv2
import h5py
import argparse
import matplotlib.pyplot as plt
DT = 0.02
# JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
JOINT_NAMES = ["joint0", "joint1", "joint2", "joint3", "joint4", "joint5"]
STATE_NAMES = JOINT_NAMES + ["gripper"]
BASE_STATE_NAMES = ["linear_vel", "angular_vel"]
def load_hdf5(dataset_dir, dataset_name):
    """Load one recorded episode from ``<dataset_dir>/<dataset_name>.hdf5``.

    Exits the process when the file does not exist.

    Args:
        dataset_dir: Directory containing the episode file.
        dataset_name: File name without the ``.hdf5`` extension.

    Returns:
        Tuple ``(qpos, qvel, effort, action, base_action, image_dict)``.
        ``effort`` is None when the file has no ``/observations/effort``
        dataset. ``image_dict`` maps camera name to frames, decoded when the
        file's ``compress`` attribute is set.
    """
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        exit()

    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        # Effort lives under /observations. Testing the top-level keys (as
        # before) never matched, so effort was always returned as None.
        if '/observations/effort' in root:
            effort = root['/observations/effort'][()]
        else:
            effort = None
        action = root['/action'][()]
        base_action = root['/base_action'][()]
        image_dict = {
            cam_name: root[f'/observations/images/{cam_name}'][()]
            for cam_name in root['/observations/images/'].keys()
        }

    if compressed:
        # Frames are decoded without un-padding, exactly as before; the
        # per-frame lengths were read but never applied.
        # NOTE(review): presumably the decoder ignores trailing padding - confirm.
        for cam_name in image_dict:
            image_dict[cam_name] = [
                cv2.imdecode(padded_compressed_image, 1)
                for padded_compressed_image in image_dict[cam_name]
            ]

    return qpos, qvel, effort, action, base_action, image_dict
def main(args):
    """Load one episode and write its video and plots.

    Args:
        args: Mapping with ``dataset_dir``, ``episode_idx`` and ``task_name``.
            The episode is read from ``<dataset_dir>/<task_name>`` and the
            outputs are written directly into ``dataset_dir``.
    """
    dataset_dir, episode_idx, task_name = (
        args['dataset_dir'], args['episode_idx'], args['task_name']
    )
    dataset_name = f'episode_{episode_idx}'

    def out_path(suffix):
        # Output files sit next to the task directory, not inside it.
        return os.path.join(dataset_dir, dataset_name + suffix)

    episode_dir = os.path.join(dataset_dir, task_name)
    qpos, _qvel, _effort, action, base_action, image_dict = load_hdf5(episode_dir, dataset_name)
    print('hdf5 loaded!!')

    save_videos(image_dict, action, DT, video_path=out_path('_video.mp4'))
    visualize_joints(qpos, action, plot_path=out_path('_qpos.png'))
    visualize_base(base_action, plot_path=out_path('_base_action.png'))
def save_videos(video, actions, dt, video_path=None):
    """Preview all cameras side by side and write the mosaic to an mp4 file.

    Args:
        video: Mapping of camera name to a frame array.
            # assumes each value is (frames, height, width, 3) - TODO confirm
        actions: Per-frame actions; the first seven values are printed as
            "left" and the remainder as "right".
        dt: Time step in seconds; the video frame rate is ``int(1 / dt)``.
        video_path: Output file path.
    """
    per_camera = [video[cam_name] for cam_name in video.keys()]
    mosaic = np.concatenate(per_camera, axis=2)  # join along the width axis
    n_frames, h, w, _ = mosaic.shape

    fps = int(1 / dt)
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for t in range(n_frames):
        frame = mosaic[t][:, :, [2, 1, 0]]  # swap the B and R channels
        cv2.imshow("images",frame)
        cv2.waitKey(30)
        left_cmd = np.round(actions[t][:7], 3)
        right_cmd = np.round(actions[t][7:], 3)
        print("episode_id: ", t, "left: ", left_cmd, "right: ", right_cmd, "\n")
        writer.write(frame)
    writer.release()
    print(f'Saved video to: {video_path}')
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
    """Plot every joint-state dimension over time and save the figure.

    Args:
        qpos_list: Joint states, convertible to a (timesteps, dims) array.
        command_list: Commands; accepted for interface compatibility but not
            plotted (the command plot is disabled).
        plot_path: Output image path.
        ylim: Optional y-axis limits applied to every subplot.
        label_overwrite: Optional ``(state_label, command_label)`` pair.
    """
    if label_overwrite:
        state_label, _ = label_overwrite
    else:
        state_label = 'State'

    qpos = np.array(qpos_list)  # ts, dim
    _, num_dim = qpos.shape

    # squeeze=False always returns a 2-D axes array, so one dimension works
    # too; plain subplots() returns a bare Axes that cannot be indexed.
    fig, axes_grid = plt.subplots(num_dim, 1, figsize=(8, 2 * num_dim), squeeze=False)
    axs = axes_grid[:, 0]
    try:
        all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
        for dim_idx, ax in enumerate(axs):
            ax.plot(qpos[:, dim_idx], label=state_label, color='orangered')
            ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
            ax.legend()
            if ylim:
                ax.set_ylim(ylim)

        plt.tight_layout()
        plt.savefig(plot_path)
        print(f'Saved qpos plot to: {plot_path}')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
def visualize_base(readings, plot_path=None):
    """Plot raw and moving-average-smoothed base readings and save the figure.

    Each dimension is plotted raw and smoothed with 20-, 10- and 5-sample
    moving averages.

    Args:
        readings: Base readings, convertible to a (timesteps, dims) array.
        plot_path: Output image path.
    """
    readings = np.array(readings)  # ts, dim
    _, num_dim = readings.shape

    # squeeze=False always returns a 2-D axes array, so one dimension works
    # too; plain subplots() returns a bare Axes that cannot be indexed.
    fig, axes_grid = plt.subplots(num_dim, 1, figsize=(8, 2 * num_dim), squeeze=False)
    axs = axes_grid[:, 0]
    try:
        all_names = BASE_STATE_NAMES
        for dim_idx, ax in enumerate(axs):
            series = readings[:, dim_idx]
            ax.plot(series, label='raw')
            for window in (20, 10, 5):
                ax.plot(np.convolve(series, np.ones(window)/window, mode='same'),
                        label=f'smoothed_{window}')
            ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
            ax.legend()

        plt.tight_layout()
        plt.savefig(plot_path)
        # This plot shows base actions; the old message said "effort".
        print(f'Saved base action plot to: {plot_path}')
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
if __name__ == '__main__':
    # Command-line entry point: parse the episode location, then run main().
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--dataset_dir', dict(type=str, help='Dataset dir.', required=True)),
        ('--task_name', dict(type=str, help='Task name.',
                             default="aloha_mobile_dummy", required=False)),
        ('--episode_idx', dict(type=int, help='Episode index.', default=0, required=False)),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, action='store', **options)
    main(vars(cli.parse_args()))

1
lerobot Submodule

Submodule lerobot added at 1c873df5c0

18
lerobot_aloha/README.MD Normal file
View File

@@ -0,0 +1,18 @@
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5
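# fd token (Hugging Face access token; secrets must not be committed: revoke this one and load it from an environment variable instead)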
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
# act
python lerobot/lerobot/scripts/train.py --policy.type=act --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_the_scale_with_head_camera/ --dataset.repo_id=maic/move_tube_on_scale_head --job_name=act_with_head --output_dir=outputs/train/act_move_bottle_on_scale_with_head
python lerobot/lerobot/scripts/visualize_dataset_html.py --root /home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_a_scale_without_head_camera --repo-id xxx
# pi0 ft
python lerobot/lerobot/scripts/train.py \
--policy.path=lerobot/pi0 \
--wandb.enable=true \
--dataset.root=/home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_a_scale_without_head_camera \
--dataset.repo_id=maic/move_a_reagent_bottle_on_a_scale_without_head_camera \
--job_name=pi0_without_head \
--output_dir=outputs/train/move_a_reagent_bottle_on_a_scale_without_head_camera

Binary file not shown.

View File

@@ -1,17 +1,13 @@
import yaml
import cv2
import numpy as np
import collections
import dm_env
import argparse
from typing import Dict, List, Any, Optional
from collections import deque
import rospy
from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from nav_msgs.msg import Odometry
from rosrobot import Robot
from sensor_msgs.msg import JointState
from .rosrobot import Robot
import torch
import time
@@ -40,10 +36,15 @@ class AgilexRobot(Robot):
# print("can not get data from puppet topic")
# return None
if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0:
print("can not get data from puppet topic")
return None
# 检查必要的机械臂数据是否可用
required_arms = ['puppet_left', 'puppet_right']
for arm_name in required_arms:
if arm_name not in self.sync_arm_queues or len(self.sync_arm_queues[arm_name]) == 0:
print(f"can not get data from {arm_name} topic")
return None
# 时间戳误差
tolerance = 0.1 # 允许 100ms 的时间戳偏差
# 计算最小时间戳
timestamps = [
q[-1].header.stamp.to_sec()
@@ -59,18 +60,16 @@ class AgilexRobot(Robot):
min_time = min(timestamps)
# 检查数据同步性
# 检查数据同步性(允许 100ms 偏差)
for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
if queue[-1].header.stamp.to_sec() < min_time:
if queue[-1].header.stamp.to_sec() < min_time - tolerance:
return None
if self.use_depth_image:
for queue in self.sync_depth_queues.values():
if queue[-1].header.stamp.to_sec() < min_time:
if queue[-1].header.stamp.to_sec() < min_time - tolerance:
return None
if self.use_robot_base and len(self.sync_base_queue) > 0:
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time - tolerance:
return None
# 提取同步数据
@@ -82,33 +81,35 @@ class AgilexRobot(Robot):
# 图像数据
for cam_name, queue in self.sync_img_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
queue.popleft()
frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
if queue:
frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
# 深度数据
if self.use_depth_image:
frame_data['depths'] = {}
for cam_name, queue in self.sync_depth_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
queue.popleft()
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
# 保持原有的边界填充
frame_data['depths'][cam_name] = cv2.copyMakeBorder(
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
)
if queue:
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
frame_data['depths'][cam_name] = cv2.copyMakeBorder(
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
)
# 机械臂数据
for arm_name, queue in self.sync_arm_queues.items():
while queue[0].header.stamp.to_sec() < min_time:
while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
queue.popleft()
frame_data['arms'][arm_name] = queue.popleft()
if queue:
frame_data['arms'][arm_name] = queue.popleft()
# 基座数据
if self.use_robot_base and len(self.sync_base_queue) > 0:
while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
while self.sync_base_queue and self.sync_base_queue[0].header.stamp.to_sec() < min_time - tolerance:
self.sync_base_queue.popleft()
frame_data['base'] = self.sync_base_queue.popleft()
if self.sync_base_queue:
frame_data['base'] = self.sync_base_queue.popleft()
return frame_data
@@ -126,7 +127,12 @@ class AgilexRobot(Robot):
return None, None
if any(len(q) == 0 for q in self.sync_arm_queues.values()):
# 遍历字典,检查并报告每个空队列
for arm_name, queue in self.sync_arm_queues.items():
if len(queue) == 0:
print(f"{arm_name} arm not connected or queue is empty")
return None, None
# 计算最小时间戳
timestamps = [
@@ -206,11 +212,11 @@ class AgilexRobot(Robot):
if arm_states:
obs_dict["observation.state"] = torch.tensor(np.concatenate(arm_states).reshape(-1)) # 先转Python列表
if arm_velocity:
obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))
# if arm_velocity:
# obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))
if arm_effort:
obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))
# if arm_effort:
# obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))
if actions:
action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))
@@ -272,7 +278,7 @@ class AgilexRobot(Robot):
# Log timing information
# self.logs[f"read_arm_{arm_name}_pos_dt_s"] = time.perf_counter() - before_read_t
print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
# print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
# Combine all arm states into single tensor
if arm_states:
@@ -295,7 +301,7 @@ class AgilexRobot(Robot):
# Log timing information
# self.logs[f"read_camera_{cam_name}_dt_s"] = time.perf_counter() - before_camread_t
print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)
# print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)
# Process depth data if enabled
if self.use_depth_image and 'depths' in frame_data:
@@ -307,7 +313,7 @@ class AgilexRobot(Robot):
obs_dict[f"observation.images.depth_{cam_name}"] = depth_tensor
# self.logs[f"read_depth_{cam_name}_dt_s"] = time.perf_counter() - before_depthread_t
print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)
# print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)
# Process base velocity if enabled
if self.use_robot_base and 'base' in frame_data:
@@ -330,12 +336,18 @@ class AgilexRobot(Robot):
Returns:
The actual action that was sent (may be clipped if safety checks are implemented)
"""
# if not hasattr(self, 'puppet_arm_publishers'):
# # Initialize publishers on first call
# self._init_action_publishers()
# 默认速度和力矩值
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.03296661376953125]
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
0.0, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.03296661376953125]
# Convert tensor to numpy array if needed
if isinstance(action, torch.Tensor):
@@ -359,22 +371,23 @@ class AgilexRobot(Robot):
arm_velocity = last_velocity[from_idx:to_idx]
arm_effort = last_effort[from_idx:to_idx]
from_idx = to_idx
# fix
arm_action[-1] = max(arm_action[-1]*15, 0)
# Apply safety checks if configured
if 'max_relative_target' in self.config:
# Get current position from the queue
if len(self.sync_arm_queues[arm_name]) > 0:
current_state = self.sync_arm_queues[arm_name][-1]
current_pos = np.array(current_state.position)
# Clip the action to stay within max relative target
max_delta = self.config['max_relative_target']
clipped_action = np.clip(arm_action,
current_pos - max_delta,
current_pos + max_delta)
arm_action = clipped_action
# # Get current position from the queue
# if len(arm_action) > 0:
# # Clip the action to stay within max relative target
# max_delta = 0.1
# clipped_action = np.clip(arm_action,
# arm_action - max_delta,
# arm_action + max_delta)
# arm_action = clipped_action
action_sent.append(arm_action)
# action_sent.append(arm_action)
# Create and publish JointState message
joint_state = JointState()

View File

@@ -0,0 +1,422 @@
import yaml
from collections import deque
import rospy
from cv_bridge import CvBridge
from typing import Dict, Any, Optional, List
from sensor_msgs.msg import Image, JointState
from nav_msgs.msg import Odometry
import argparse
class RobotConfig:
    """Configuration management for robot components"""

    def __init__(self, config_file: str):
        """
        Initialize robot configuration from YAML file

        Args:
            config_file: Path to YAML configuration file

        Raises:
            ValueError: If a required config section is missing
        """
        self.config = self._load_yaml(config_file)
        self._validate_config()

    def _load_yaml(self, config_file: str) -> Dict[str, Any]:
        """Load configuration from YAML file (an empty file yields ``{}``)"""
        with open(config_file, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; fall back to an
            # empty mapping so validation reports the missing section
            # instead of raising TypeError.
            return yaml.safe_load(f) or {}

    def _validate_config(self) -> None:
        """Validate configuration completeness"""
        required_sections = ['cameras', 'arm']
        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required config section: {section}")

    def merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
        """
        Merge runtime arguments into configuration

        Only attributes that exist on ``args`` and are not None override the
        configuration.

        Args:
            args: Runtime arguments from command line
        """
        if args is None:
            return

        # config key -> attribute name on ``args``
        arg_names = {
            'frame_rate': 'fps',
            'max_timesteps': 'max_timesteps',
            'episode_idx': 'episode_idx',
            'use_depth_image': 'use_depth_image',
            'use_robot_base': 'use_base',
            'video': 'video',
            'control_type': 'control_type',
        }
        for key, attr in arg_names.items():
            # Default to None (not False) so a missing attribute never
            # clobbers a value that was set in the configuration file.
            value = getattr(args, attr, None)
            if value is not None:
                self.config[key] = value

    def get(self, key: str, default=None) -> Any:
        """Get configuration value with default fallback"""
        return self.config.get(key, default)
class RosAdapter:
    """Adapter for ROS communication.

    Holds the shared ``CvBridge`` plus two registries (``subscribers`` and
    ``publishers``) that other components fill in under their own keys.
    """

    def __init__(self, config: RobotConfig):
        """
        Initialize ROS adapter

        Args:
            config: Robot configuration
        """
        self.config = config
        self.bridge = CvBridge()
        self.subscribers = {}
        self.publishers = {}

    def init_ros_node(self, node_name: str = None) -> None:
        """Initialize the ROS node.

        Args:
            node_name: Node name; when omitted, the ``ros_node_name`` config
                value is used, falling back to ``generic_robot_node``.
        """
        resolved_name = (
            node_name
            if node_name is not None
            else self.config.get('ros_node_name', 'generic_robot_node')
        )
        rospy.init_node(resolved_name, anonymous=True)

    def create_subscriber(self, topic: str, msg_type, callback, queue_size: int = 1000, tcp_nodelay: bool = True):
        """Create and return a ROS subscriber (the caller registers it)."""
        return rospy.Subscriber(
            topic, msg_type, callback, queue_size=queue_size, tcp_nodelay=tcp_nodelay
        )

    def create_publisher(self, topic: str, msg_type, queue_size: int = 10):
        """Create and return a ROS publisher (the caller registers it)."""
        return rospy.Publisher(topic, msg_type, queue_size=queue_size)

    def log_status(self) -> None:
        """Log the node name and the state of every registered subscriber."""
        rospy.loginfo("\n=== ROS订阅状态 ===")
        rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
        rospy.loginfo("活跃的订阅者:")
        for topic, sub in self.subscribers.items():
            state = '活跃' if sub.impl else '未连接'
            rospy.loginfo(f" - {topic}: {state}")
        rospy.loginfo("=================")
class RobotSensors:
    """Management of robot sensors (cameras, depth sensors, mobile base).

    Incoming ROS messages are buffered in bounded deques, one per topic, so
    that a consumer can later pick time-aligned samples from them.
    """

    # Upper bound on buffered messages per topic. The deques are created with
    # this ``maxlen``, so appending to a full queue drops the oldest message
    # automatically and no manual eviction is needed.
    QUEUE_MAXLEN = 2000

    def __init__(self, config: RobotConfig, ros_adapter: RosAdapter):
        """
        Initialize robot sensors

        Args:
            config: Robot configuration
            ros_adapter: ROS communication adapter
        """
        self.config = config
        self.ros_adapter = ros_adapter
        self.bridge = ros_adapter.bridge

        # Camera data
        self.cameras = config.get('cameras', {})
        self.sync_img_queues = {name: self._new_queue() for name in self.cameras}

        # Depth data: the attribute only exists when depth is enabled, and only
        # cameras that declare a depth topic get a queue.
        self.use_depth_image = config.get('use_depth_image', False)
        if self.use_depth_image:
            self.sync_depth_queues = {
                name: self._new_queue()
                for name, cam in self.cameras.items()
                if 'depth_topic_name' in cam
            }

        # Robot base data: the attribute only exists when the base is enabled.
        self.use_robot_base = config.get('use_robot_base', False)
        if self.use_robot_base:
            self.sync_base_queue = self._new_queue()

    def _new_queue(self) -> deque:
        """Return an empty message buffer bounded to ``QUEUE_MAXLEN`` items."""
        return deque(maxlen=self.QUEUE_MAXLEN)

    def setup_subscribers(self) -> None:
        """Set up ROS subscribers for sensors"""
        self._setup_camera_subscribers()
        if self.use_robot_base:
            self._setup_base_subscriber()

    def _setup_camera_subscribers(self) -> None:
        """Subscribe to the colour (and optionally depth) topic of each camera."""
        for cam_name, cam_config in self.cameras.items():
            if 'img_topic_name' in cam_config:
                self.ros_adapter.subscribers[f"camera_{cam_name}"] = self.ros_adapter.create_subscriber(
                    cam_config['img_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=False)
                )
            if self.use_depth_image and 'depth_topic_name' in cam_config:
                self.ros_adapter.subscribers[f"depth_{cam_name}"] = self.ros_adapter.create_subscriber(
                    cam_config['depth_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=True)
                )

    def _setup_base_subscriber(self) -> None:
        """Subscribe to the base odometry topic if a ``robot_base`` section exists."""
        if 'robot_base' in self.config.config:
            self.ros_adapter.subscribers['base'] = self.ros_adapter.create_subscriber(
                self.config.get('robot_base')['topic_name'],
                Odometry,
                self.robot_base_callback
            )

    def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
        """Build the subscriber callback that buffers messages for one camera.

        Args:
            cam_name: Camera key in ``self.cameras``
            is_depth: Buffer into the depth queue instead of the colour queue

        Returns:
            A one-argument callable suitable for ``rospy.Subscriber``.
        """
        def callback(msg):
            try:
                queues = self.sync_depth_queues if is_depth else self.sync_img_queues
                # Bounded deque: the oldest message is discarded when full.
                queues[cam_name].append(msg)
            except Exception as e:
                # Callback boundary: log instead of letting the error propagate.
                rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
        return callback

    def robot_base_callback(self, msg):
        """Buffer one base message (oldest is discarded when the queue is full)."""
        self.sync_base_queue.append(msg)

    def init_features(self) -> Dict[str, Any]:
        """Build the dataset feature descriptions contributed by the sensors."""
        features = {}
        self._init_camera_features(features)
        if self.use_robot_base:
            self._init_base_features(features)
        return features

    def _init_camera_features(self, features: Dict[str, Any]) -> None:
        """Add one image feature (and optionally one depth feature) per camera."""
        for cam_name, cam_config in self.cameras.items():
            # Regular images
            features[f"observation.images.{cam_name}"] = {
                "dtype": "video" if self.config.get("video", False) else "image",
                "shape": cam_config.get("rgb_shape", [480, 640, 3]),
                "names": ["height", "width", "channel"],
            }
            if self.config.get("use_depth_image", False):
                # NOTE(review): the first axis is read from the "width" key and
                # the second from "height", and "names" lists two axes for a
                # three-axis shape. Kept as-is so existing dataset metadata
                # does not change - confirm the intended layout.
                features[f"observation.images.depth_{cam_name}"] = {
                    "dtype": "uint16",
                    "shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1),
                    "names": ["height", "width"],
                }

    def _init_base_features(self, features: Dict[str, Any]) -> None:
        """Add the base velocity feature."""
        features["observation.base_vel"] = {
            "dtype": "float32",
            "shape": (2,),
            "names": ["linear_x", "angular_z"]
        }
class RobotActuators:
    """Management of robot actuators (arms).

    In ``record`` mode every configured arm is subscribed to. In any other
    mode only arms whose name contains ``puppet`` are subscribed to, and arms
    whose name contains ``master`` get a publisher instead.
    """

    # Upper bound on buffered joint-state messages per arm. The deques are
    # created with this ``maxlen``, so the oldest message is dropped
    # automatically once a queue is full.
    QUEUE_MAXLEN = 2000

    def __init__(self, config: RobotConfig, ros_adapter: RosAdapter):
        """
        Initialize robot actuators

        Args:
            config: Robot configuration
            ros_adapter: ROS communication adapter
        """
        self.config = config
        self.ros_adapter = ros_adapter

        # Arm data
        self.arms = config.get('arm', {})
        record_mode = config.get('control_type', '') == 'record'
        # Outside record mode only the puppet arms are listened to.
        self.sync_arm_queues = {
            name: deque(maxlen=self.QUEUE_MAXLEN)
            for name in self.arms
            if record_mode or 'puppet' in name
        }

    def setup_subscribers_publishers(self) -> None:
        """Set up ROS subscribers and publishers for actuators"""
        self._setup_arm_subscribers_publishers()

    def _setup_arm_subscribers_publishers(self) -> None:
        """Create the per-arm subscribers/publishers for the current mode.

        Arms without a ``topic_name`` entry are skipped in every mode.
        """
        record_mode = self.config.get('control_type', '') == 'record'
        for arm_name, arm_config in self.arms.items():
            if 'topic_name' not in arm_config:
                # Nothing to connect to; previously this raised KeyError
                # outside record mode while record mode silently skipped it.
                continue
            if record_mode or 'puppet' in arm_name:
                self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber(
                    arm_config['topic_name'],
                    JointState,
                    self._make_arm_callback(arm_name)
                )
            if not record_mode and 'master' in arm_name:
                self.ros_adapter.publishers[f"arm_{arm_name}"] = self.ros_adapter.create_publisher(
                    arm_config['topic_name'],
                    JointState
                )

    def _make_arm_callback(self, arm_name: str):
        """Build the subscriber callback that buffers messages for one arm."""
        def callback(msg):
            try:
                # Bounded deque: the oldest message is discarded when full.
                self.sync_arm_queues[arm_name].append(msg)
            except Exception as e:
                # Callback boundary: log instead of letting the error propagate.
                rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
        return callback

    def init_features(self) -> Dict[str, Any]:
        """Build the dataset feature descriptions contributed by the arms."""
        features = {}
        self._init_state_features(features)
        self._init_action_features(features)
        return features

    def _init_state_features(self, features: Dict[str, Any]) -> None:
        """Add ``observation.state`` sized by the ``state.motors`` config list.

        Velocity and effort features are currently not generated.
        """
        state = self.config.get('state', {})
        features["observation.state"] = {
            "dtype": "float32",
            "shape": (len(state.get('motors', "")),),
            "names": {"motors": state.get('motors', "")}
        }

    def _init_action_features(self, features: Dict[str, Any]) -> None:
        """Add ``action`` sized by the ``action.motors`` config list."""
        action = self.config.get('action', {})
        features["action"] = {
            "dtype": "float32",
            "shape": (len(action.get('motors', "")),),
            "names": {"motors": action.get('motors', "")}
        }
class RobotDataManager:
    """Management of robot data collection and synchronization"""

    # Minimum number of buffered messages required before warmup succeeds.
    MIN_IMAGE_MSGS = 200  # applies to both colour and depth image queues
    MIN_BASE_MSGS = 20

    def __init__(self, config: RobotConfig, sensors: RobotSensors, actuators: RobotActuators):
        """
        Initialize robot data manager

        Args:
            config: Robot configuration
            sensors: Robot sensors component
            actuators: Robot actuators component
        """
        self.config = config
        self.sensors = sensors
        self.actuators = actuators

    def _queue_shortfalls(self) -> List[str]:
        """Describe the sensor queues that are not yet sufficiently filled.

        At most one message is produced per category (colour, depth, base):
        only the first lagging camera of each kind is reported.

        Returns:
            Human-readable status lines; an empty list means all queues are ready.
        """
        # NOTE(review): arm queues are not checked here - confirm whether
        # warmup is meant to wait for joint-state messages as well.
        shortfalls = []
        for cam_name in self.sensors.cameras:
            count = len(self.sensors.sync_img_queues[cam_name])
            if count < self.MIN_IMAGE_MSGS:
                shortfalls.append(
                    f"Waiting for camera {cam_name} (current: {count}/{self.MIN_IMAGE_MSGS})"
                )
                break
        if self.sensors.use_depth_image:
            for cam_name in self.sensors.sync_depth_queues:
                count = len(self.sensors.sync_depth_queues[cam_name])
                if count < self.MIN_IMAGE_MSGS:
                    shortfalls.append(
                        f"Waiting for depth camera {cam_name} (current: {count}/{self.MIN_IMAGE_MSGS})"
                    )
                    break
        if self.sensors.use_robot_base:
            count = len(self.sensors.sync_base_queue)
            if count < self.MIN_BASE_MSGS:
                shortfalls.append(f"Waiting for base (current: {count}/{self.MIN_BASE_MSGS})")
        return shortfalls

    def warmup(self, timeout: float = 30.0) -> bool:
        """
        Wait until all monitored sensor queues have sufficient messages

        The queues are polled at 10 Hz using ROS time.

        Args:
            timeout: Maximum time to wait in seconds

        Returns:
            bool: True if warmup succeeded, False if it timed out or ROS shut down
        """
        start_time = rospy.Time.now().to_sec()
        rate = rospy.Rate(10)  # Check at 10Hz
        rospy.loginfo("Starting warmup process...")
        while not rospy.is_shutdown():
            if rospy.Time.now().to_sec() - start_time > timeout:
                rospy.logwarn("Warmup timed out before all queues were filled")
                return False

            rospy.loginfo(f"Nums of camera is {len(self.sensors.cameras)}")
            shortfalls = self._queue_shortfalls()
            for message in shortfalls:
                rospy.loginfo(message)
            if not shortfalls:
                rospy.loginfo("Warmup completed successfully")
                return True
            rate.sleep()
        return False

View File

@@ -0,0 +1,136 @@
import yaml
from typing import Dict, Any, Optional, List
import argparse
from .robot_components import RobotConfig, RosAdapter, RobotSensors, RobotActuators, RobotDataManager
class Robot:
    """Base robot: wires the configuration, ROS adapter, sensors, actuators
    and data manager together and exposes their state through properties.

    Subclasses must implement ``get_frame()`` and ``process()``.
    """

    def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
        """
        Build all components, then start ROS, build features and warm up.

        Args:
            config_file: Path to the YAML configuration file
            args: Runtime arguments that override configuration values
        """
        # Components
        config = RobotConfig(config_file)
        config.merge_runtime_args(args)
        self.config = config
        self.ros_adapter = RosAdapter(config)
        self.sensors = RobotSensors(config, self.ros_adapter)
        self.actuators = RobotActuators(config, self.ros_adapter)
        self.data_manager = RobotDataManager(config, self.sensors, self.actuators)

        # ROS and features; the warmup result is not inspected here.
        self.init_ros()
        self.init_features()
        self.warmup()

    def get(self, key: str, default=None) -> Any:
        """Return a configuration value, or ``default`` when it is missing."""
        return self.config.get(key, default)

    @property
    def bridge(self):
        """The CvBridge owned by the ROS adapter."""
        return self.ros_adapter.bridge

    @property
    def subscribers(self):
        """Registered ROS subscribers."""
        return self.ros_adapter.subscribers

    @property
    def publishers(self):
        """Registered ROS publishers."""
        return self.ros_adapter.publishers

    @property
    def cameras(self):
        """Camera configuration."""
        return self.sensors.cameras

    @property
    def arms(self):
        """Arm configuration."""
        return self.actuators.arms

    @property
    def sync_img_queues(self):
        """Colour image queues."""
        return self.sensors.sync_img_queues

    @property
    def sync_depth_queues(self):
        """Depth image queues, or an empty dict when the sensors have none."""
        return getattr(self.sensors, 'sync_depth_queues', {})

    @property
    def sync_arm_queues(self):
        """Arm joint-state queues."""
        return self.actuators.sync_arm_queues

    @property
    def sync_base_queue(self):
        """Base queue, or ``None`` when the sensors have none."""
        return getattr(self.sensors, 'sync_base_queue', None)

    @property
    def use_depth_image(self):
        """Whether depth images are enabled."""
        return self.sensors.use_depth_image

    @property
    def use_robot_base(self):
        """Whether the robot base is enabled."""
        return self.sensors.use_robot_base

    def init_ros(self) -> None:
        """Template method: start the node, connect topics, log the status."""
        self.ros_adapter.init_ros_node()
        # Subscribers and publishers of sensors and actuators
        self.sensors.setup_subscribers()
        self.actuators.setup_subscribers_publishers()
        # Report the ROS status
        self.ros_adapter.log_status()

    def init_features(self):
        """
        Build ``self.features`` from the sensor and actuator features
        (actuator entries win on key clashes) and pretty-print the result.
        """
        import pprint
        self.features = {
            **self.sensors.init_features(),
            **self.actuators.init_features(),
        }
        pprint.pprint(self.features, indent=4)

    def warmup(self, timeout: float = 30.0) -> bool:
        """Delegate to the data manager's warmup.

        Args:
            timeout: Maximum time to wait in seconds before giving up

        Returns:
            bool: Whatever the data manager's ``warmup`` returns
        """
        return self.data_manager.warmup(timeout)

    def get_frame(self) -> Optional[Dict[str, Any]]:
        """Template method: return one synchronized frame of data."""
        raise NotImplementedError("Subclasses must implement get_frame()")

    def process(self) -> tuple:
        """Template method: main processing loop."""
        raise NotImplementedError("Subclasses must implement process()")

View File

@@ -0,0 +1,59 @@
import yaml
import argparse
from typing import Dict, List, Any, Optional, Type
from .rosrobot import Robot
from .agilex_robot import AgilexRobot
class RobotFactory:
    """Factory for creating robot instances based on configuration"""

    # Registry mapping a robot type identifier to its implementing class.
    _registry = {}

    @classmethod
    def register(cls, robot_type: str, robot_class: Type[Robot]) -> None:
        """
        Register a robot type (re-registering a name replaces the old class).

        Args:
            robot_type: Robot type identifier
            robot_class: Class implementing that robot
        """
        cls._registry[robot_type] = robot_class

    @classmethod
    def create(cls, config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
        """
        Create the robot instance selected by the ``robot_type`` config key.

        Args:
            config_file: Path to the YAML configuration file
            args: Runtime arguments forwarded to the robot constructor

        Returns:
            Robot: The created robot instance

        Raises:
            ValueError: If the file does not contain a YAML mapping, or the
                requested robot type is not registered
        """
        with open(config_file, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        # An empty file parses to None; fail with a clear message instead of
        # an AttributeError on ``.get``.
        if not isinstance(config, dict):
            raise ValueError(f"Configuration file {config_file} must contain a YAML mapping")
        robot_type = config.get('robot_type', 'agilex')

        # Fall back to the default robot types if nothing was registered.
        if not cls._registry:
            cls.register('agilex', AgilexRobot)
            cls.register('aloha_agilex', AgilexRobot)  # alias

        # Only the lookup is guarded, so a KeyError raised inside a robot
        # constructor is never mistaken for an unknown type.
        try:
            robot_class = cls._registry[robot_type]
        except KeyError:
            raise ValueError(
                f"Unsupported robot type: {robot_type}. Available types: {list(cls._registry.keys())}"
            ) from None
        return robot_class(config_file, args)
# Register the available robot types at import time
RobotFactory.register('agilex', AgilexRobot)
RobotFactory.register('aloha_agilex', AgilexRobot) # alias support

View File

@@ -0,0 +1,12 @@
# Import utility functions for easy access
from .control_utils import (
predict_action,
control_loop,
init_keyboard_listener,
stop_recording,
record_episode,
is_headless,
busy_wait
)
from .data_utils import record
from .replay_utils import replay

View File

@@ -0,0 +1,312 @@
import logging
import time
import torch
import rospy
import cv2
import numpy as np
from contextlib import nullcontext
from copy import copy
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import get_safe_torch_device
def is_headless():
    """
    Check if the environment is headless (no display available).

    The check tries to create, update and destroy a hidden Tk root window;
    any ordinary failure (tkinter missing, no display, ...) counts as headless.

    Returns:
        bool: True if the environment is headless, False otherwise.
    """
    try:
        import tkinter as tk
        root = tk.Tk()
        try:
            root.withdraw()
            root.update()
        finally:
            # Always release the window, even when the probe itself fails.
            root.destroy()
        return False
    except Exception:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit are not swallowed and misreported as "headless".
        return True
def predict_action(observation, policy, device, use_amp):
    """
    Run the policy on one observation and return the action on the CPU.

    Entries whose key contains ``"image"`` are converted to float32 in [0, 1]
    and permuted with ``(2, 0, 1)``; every entry then gets a leading batch
    dimension and is moved to ``device``. The caller's dict is not modified
    (a shallow copy is used).

    Args:
        observation: Mapping of feature name to tensor
        policy: Object exposing ``select_action(observation)``
        device: Torch device
        use_amp: Whether to use automatic mixed precision (only on CUDA)

    Returns:
        torch.Tensor: Predicted action with the batch dimension removed
    """
    batch = copy(observation)
    amp_enabled = device.type == "cuda" and use_amp
    autocast_ctx = torch.autocast(device_type=device.type) if amp_enabled else nullcontext()

    with torch.inference_mode(), autocast_ctx:
        for key in batch:
            tensor = batch[key]
            if "image" in key:
                # Scale to [0, 1] and reorder the axes with permute(2, 0, 1).
                tensor = (tensor.type(torch.float32) / 255).permute(2, 0, 1).contiguous()
            # Add the batch dimension and move to the target device.
            batch[key] = tensor.unsqueeze(0).to(device)

        # Next action for the current observation, without the batch
        # dimension and back on the CPU.
        return policy.select_action(batch).squeeze(0).to("cpu")
def _show_camera_grid(observation, tile_width=640, tile_height=480):
    """Tile the non-depth image entries of *observation* into one OpenCV window.

    Each image is converted with ``COLOR_RGB2BGR``, resized to
    ``tile_width`` x ``tile_height`` and labelled with its key. Does nothing
    when the observation has no matching entry.
    """
    image_keys = [key for key in observation if "image" in key and "depth" not in key]
    if not image_keys:
        return

    # Grid layout as close to a square as possible.
    grid_cols = int(np.ceil(np.sqrt(len(image_keys))))
    grid_rows = int(np.ceil(len(image_keys) / grid_cols))
    canvas = np.zeros((grid_rows * tile_height, grid_cols * tile_width, 3), dtype=np.uint8)

    for idx, key in enumerate(image_keys):
        row, col = divmod(idx, grid_cols)
        y_start = row * tile_height
        x_start = col * tile_width
        # Values must expose ``.numpy()`` (presumably torch tensors - verify
        # against the robot implementation).
        image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
        canvas[y_start:y_start + tile_height, x_start:x_start + tile_width] = cv2.resize(
            image, (tile_width, tile_height)
        )
        cv2.putText(
            canvas, key, (x_start + 5, y_start + 15),
            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1,
        )

    cv2.imshow("Camera Views", canvas)
    cv2.waitKey(1)


def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """
    Main control loop for robot operation.

    Runs until ``control_time_s`` has elapsed, ROS shuts down, or
    ``events["exit_early"]`` is set (the flag is reset before returning).

    Args:
        robot: Robot instance
        control_time_s: Control time in seconds (``None`` means no limit)
        teleoperate: Whether to use teleoperation
        display_cameras: Whether to display camera feeds
        dataset: Dataset for recording
        events: Event dictionary
        policy: Policy model
        fps: Frames per second (``None`` disables frame pacing)
        single_task: Task name

    Raises:
        ValueError: If a dataset is given without ``single_task``, if the
            dataset fps differs from ``fps``, or if a frame must be recorded
            but neither teleoperation nor a policy produced an action.
    """
    if events is None:
        events = {"exit_early": False}
    if control_time_s is None:
        control_time_s = float("inf")
    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")
    if dataset is not None and fps is not None and dataset.fps != fps:
        # ``dataset.fps`` (attribute), not ``dataset['fps']``: subscripting the
        # dataset here raised TypeError instead of this ValueError.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    # The display probe creates a Tk window, so evaluate it once, not per frame.
    show_cameras = display_cameras and not is_headless()
    # ``fps`` may be None; rospy.Rate cannot be built from it.
    rate = rospy.Rate(fps) if fps is not None else None

    timestamp = 0
    start_episode_t = time.perf_counter()
    sync_failure_reported = False
    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()
        action = None

        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                if not sync_failure_reported:
                    print("sync data fail, retrying...\n")
                    sync_failure_reported = True
                # NOTE(review): this retry path skips the timeout and
                # exit_early checks below, so the loop cannot end while
                # synchronization keeps failing - confirm this is intended.
                if rate is not None:
                    rate.sleep()
                else:
                    time.sleep(0.01)
                continue
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            if action is None:
                raise ValueError("Recording a frame requires teleoperation or a policy to provide an action.")
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if show_cameras:
            _show_camera_grid(observation)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
def init_keyboard_listener():
    """
    Start a keyboard listener that translates key presses into control events.

    Right arrow ends the current loop, left arrow ends it and requests a
    re-record, Escape stops recording altogether and up arrow sets
    ``record_start``. Monitoring keyboard events might require elevated
    permissions for the terminal.

    Returns:
        tuple: (listener, events) - the started listener (``None`` in a
        headless environment) and the shared events dictionary
    """
    events = {
        "exit_early": False,
        "record_start": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    # (key, message, event flags to set) - checked in this order.
    bindings = (
        (
            keyboard.Key.right,
            "Right arrow key pressed. Exiting loop...",
            {"exit_early": True, "record_start": False},
        ),
        (
            keyboard.Key.left,
            "Left arrow key pressed. Exiting loop and rerecord the last episode...",
            {"rerecord_episode": True, "exit_early": True},
        ),
        (
            keyboard.Key.esc,
            "Escape key pressed. Stopping data recording...",
            {"stop_recording": True, "exit_early": True},
        ),
        (
            keyboard.Key.up,
            "Up arrow pressed. Start data recording...",
            {"record_start": True},
        ),
    )

    def on_press(key):
        try:
            for bound_key, message, flags in bindings:
                if key == bound_key:
                    print(message)
                    events.update(flags)
                    break
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events
def stop_recording(robot, listener, display_cameras):
    """
    Release the keyboard listener and camera windows after recording.

    Nothing is done in a headless environment.

    Args:
        robot: Robot instance (not used by this function)
        listener: Keyboard listener, or ``None``
        display_cameras: Whether camera windows were shown
    """
    if is_headless():
        return
    if listener is not None:
        listener.stop()
    if display_cameras:
        cv2.destroyAllWindows()
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    """
    Record a single episode by running the control loop.

    Teleoperation is used exactly when no policy is supplied.

    Args:
        robot: Robot instance
        dataset: Dataset for recording
        events: Event dictionary
        episode_time_s: Episode time in seconds
        display_cameras: Whether to display camera feeds
        policy: Policy model, or ``None`` to teleoperate
        fps: Frames per second
        single_task: Task name
    """
    use_teleop = policy is None
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        teleoperate=use_teleop,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        single_task=single_task,
    )
def busy_wait(seconds):
    """
    Spin (without sleeping) until the given duration has elapsed.

    Args:
        seconds: Number of seconds to wait; zero or negative returns at once
    """
    if seconds > 0:
        deadline = time.perf_counter() + seconds
        while time.perf_counter() < deadline:
            pass

View File

@@ -0,0 +1,105 @@
import logging
import time
from pprint import pprint
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import log_say, has_method
from common.utils.control_utils import init_keyboard_listener, stop_recording, record_episode
def record(
    robot,
    cfg
) -> LeRobotDataset:
    """
    Record episodes with the robot as described by ``cfg``.

    An existing dataset is reopened when ``cfg.resume`` is set, otherwise a
    new one is created from ``robot.features``. Episodes are recorded until
    ``cfg.num_episodes`` have been saved or recording is stopped via the
    keyboard; an episode flagged for re-recording is discarded and repeated.

    Args:
        robot: Robot instance
        cfg: Configuration object

    Returns:
        LeRobotDataset: Dataset with recorded episodes
    """
    camera_count = len(robot.cameras)

    # Reopen the existing dataset or create a fresh one
    if cfg.resume:
        dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
        if camera_count > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * camera_count,
            )
    else:
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * camera_count,
        )

    # Optional pretrained policy
    if cfg.policy is None:
        policy = None
    else:
        policy = make_policy(cfg.policy, ds_meta=dataset.meta)

    listener, events = init_keyboard_listener()

    # Recording instructions for the operator
    print()
    print(f"开始记录轨迹，共需要记录{cfg.num_episodes}\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")

    recorded_episodes = 0
    while recorded_episodes < cfg.num_episodes:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # No reset prompt after the last episode unless it is being re-recorded
        more_to_record = recorded_episodes < cfg.num_episodes - 1
        if not events["stop_recording"] and (more_to_record or events["rerecord_episode"]):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")

        if events["rerecord_episode"]:
            # Drop the buffered frames and record this episode again
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset

View File

@@ -0,0 +1,32 @@
import time
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from common.utils.control_utils import busy_wait
def replay(
    robot,
    cfg,
):
    """
    Send the recorded actions of one episode to the robot at ``cfg.fps``.

    Args:
        robot: Robot instance
        cfg: Configuration object (``repo_id``, ``root``, ``episode``, ``fps``)
    """
    # Load only the requested episode
    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")
    print(f"Replaying episode {cfg.episode} from dataset {cfg.repo_id}")
    print(f"Total frames: {dataset.num_frames}")

    for idx in range(dataset.num_frames):
        frame_start_t = time.perf_counter()
        robot.send_action(actions[idx]["action"])
        # Wait out the remainder of the frame period
        elapsed_s = time.perf_counter() - frame_start_t
        busy_wait(1 / cfg.fps - elapsed_s)

View File

@@ -1,12 +1,20 @@
robot_type: aloha_agilex
ros_node_name: record_episodes
cameras:
cam_front:
cam_high:
# img_topic_name: /camera/color/image_raw
# depth_topic_name: /camera/depth/image_rect_raw
img_topic_name: /camera_f/color/image_raw
depth_topic_name: /camera_f/depth/image_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
rgb_shape: [480, 640, 3]
# cam_front:
# img_topic_name: /camera_f/color/image_raw
# depth_topic_name: /camera_f/depth/image_raw
# width: 480
# height: 640
# rgb_shape: [480, 640, 3]
cam_left:
img_topic_name: /camera_l/color/image_raw
depth_topic_name: /camera_l/depth/image_raw
@@ -20,6 +28,7 @@ cameras:
width: 480
height: 640
arm:
master_left:
topic_name: /master/joint_left
@@ -85,41 +94,41 @@ state:
"right_none"
]
velocity:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
# velocity:
# motors: [
# "left_joint0",
# "left_joint1",
# "left_joint2",
# "left_joint3",
# "left_joint4",
# "left_joint5",
# "left_none",
# "right_joint0",
# "right_joint1",
# "right_joint2",
# "right_joint3",
# "right_joint4",
# "right_joint5",
# "right_none"
# ]
effort:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
# effort:
# motors: [
# "left_joint0",
# "left_joint1",
# "left_joint2",
# "left_joint3",
# "left_joint4",
# "left_joint5",
# "left_none",
# "right_joint0",
# "right_joint1",
# "right_joint2",
# "right_joint3",
# "right_joint4",
# "right_joint5",
# "right_none"
# ]
action:
motors: [

239
lerobot_aloha/gui_app.py Normal file
View File

@@ -0,0 +1,239 @@
import sys
import numpy as np
import cv2
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
QPushButton, QGroupBox, QTabWidget, QScrollArea, QGridLayout)
from PyQt5.QtCore import Qt, QTimer
import cv2
from PyQt5.QtGui import QImage, QPixmap
from main import get_arguments, control_robot
class ConfigGroup(QGroupBox):
    """Titled group box holding one labelled input widget per row."""

    def __init__(self, title, parent=None):
        super().__init__(title, parent)
        # NOTE: stored as an attribute named ``layout``; callers use
        # ``group.layout.addLayout(...)`` directly.
        self.layout = QVBoxLayout()
        self.setLayout(self.layout)

    @staticmethod
    def _make_input(value):
        """Pick the input widget from the type of the initial value."""
        # bool is tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            checkbox = QCheckBox()
            checkbox.setChecked(value)
            return checkbox
        if isinstance(value, int):
            spinbox = QSpinBox()
            spinbox.setRange(0, 999999)
            spinbox.setValue(value)
            return spinbox
        # string or other
        return QLineEdit(str(value))

    def add_config(self, name, value, widget_type="lineedit"):
        """Append a labelled row and return its input widget.

        Args:
            name: Label text
            value: Initial value; its type selects checkbox, spin box or line edit
            widget_type: Accepted for compatibility; not consulted here
        """
        widget = self._make_input(value)
        row = QHBoxLayout()
        row.addWidget(QLabel(name))
        row.addWidget(widget)
        self.layout.addLayout(row)
        return widget
class MainWindow(QMainWindow):
def __init__(self):
    """Build the configuration window: settings groups plus Record/Stop buttons."""
    super().__init__()
    self.setWindowTitle("MindRobot-V1 Control GUI")
    self.setGeometry(100, 100, 600, 800)  # Adjusted window size
    self.robot = None
    # Configuration object; not created here
    self.cfg = None

    # Main layout
    main_widget = QWidget()
    main_layout = QHBoxLayout()
    main_widget.setLayout(main_layout)

    # Left panel - configuration
    config_scroll = QScrollArea()
    config_widget = QWidget()
    config_layout = QVBoxLayout()
    config_widget.setLayout(config_layout)

    # General settings
    general_group = ConfigGroup("General Settings")
    self.single_task_widget = general_group.add_config("Single Task", "")
    self.single_task_widget.setMinimumWidth(300)  # Wider text box
    self.single_task_widget.textChanged.connect(self.update_repo_id_from_task)
    self.fps_widget = general_group.add_config("FPS", 30, "spinbox")
    self.resume_widget = general_group.add_config("Resume", False, "checkbox")
    self.repo_id_widget = general_group.add_config("Repo ID", "")
    # Connected before ``root_widget`` exists; safe only because the repo id
    # text does not change until the root row below has been created.
    self.repo_id_widget.textChanged.connect(self.update_root_from_repo_id)
    self.original_repo_id = ""  # Store original value
    self.play_sounds_widget = general_group.add_config("Play Sounds", False, "checkbox")

    # Config file with browse button
    config_row = QHBoxLayout()
    config_label = QLabel("Config File")
    self.config_widget = QLineEdit("/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml")
    config_browse_button = QPushButton("Browse...")
    config_browse_button.clicked.connect(self.browse_config_file)
    config_row.addWidget(config_label)
    config_row.addWidget(self.config_widget)
    config_row.addWidget(config_browse_button)
    general_group.layout.addLayout(config_row)

    # Root directory with browse button
    root_row = QHBoxLayout()
    root_label = QLabel("Root Directory")
    self.root_widget = QLineEdit(str("/home/ubuntu/LYT/lerobot_aloha/datasets/"+self.repo_id_widget.text()))
    browse_button = QPushButton("Browse...")
    browse_button.clicked.connect(self.browse_root_directory)
    root_row.addWidget(root_label)
    root_row.addWidget(self.root_widget)
    root_row.addWidget(browse_button)
    general_group.layout.addLayout(root_row)
    config_layout.addWidget(general_group)

    # Recording settings
    recording_group = ConfigGroup("Recording Settings")
    self.episode_widget = recording_group.add_config("Episode", 0, "spinbox")
    self.num_episodes_widget = recording_group.add_config("Number of Episodes", 100, "spinbox")
    # The value is copied unconverted into ``cfg.episode_time_s``, so the unit
    # shown to the user is seconds (the label previously claimed ms).
    self.episode_time_widget = recording_group.add_config("Episode Time (s)", 36000, "spinbox")
    self.video_widget = recording_group.add_config("Save Video", True, "checkbox")
    self.display_cameras_widget = recording_group.add_config("Display Cameras", True, "checkbox")
    config_layout.addWidget(recording_group)

    # Advanced settings
    advanced_group = ConfigGroup("Advanced Settings")
    self.num_writer_processes_widget = advanced_group.add_config("Writer Processes", 0, "spinbox")
    self.num_writer_threads_widget = advanced_group.add_config("Threads per Camera", 4, "spinbox")
    self.use_depth_widget = advanced_group.add_config("Use Depth Image", False, "checkbox")
    self.use_base_widget = advanced_group.add_config("Use Base", False, "checkbox")
    self.push_to_hub_widget = advanced_group.add_config("Push to Hub", False, "checkbox")
    # NOTE(review): a ``None`` default is shown as the text "None" in the line
    # edit - verify the reader of this widget maps it back to ``None``.
    self.policy_widget = advanced_group.add_config("Policy", None)
    config_layout.addWidget(advanced_group)

    # Control buttons
    control_buttons = QHBoxLayout()
    self.record_button = QPushButton("Record")
    self.record_button.clicked.connect(self.start_recording)
    control_buttons.addWidget(self.record_button)
    self.stop_button = QPushButton("Stop")
    self.stop_button.clicked.connect(self.stop_recording)
    control_buttons.addWidget(self.stop_button)
    config_layout.addLayout(control_buttons)

    config_scroll.setWidget(config_widget)
    config_scroll.setWidgetResizable(True)
    main_layout.addWidget(config_scroll, stretch=1)  # Left panel stretch factor

    # No camera view panel: the configuration panel is the whole window
    self.setCentralWidget(main_widget)

    # Robot control flag
    self.is_recording = False
def browse_config_file(self):
    """Let the user pick a YAML config file and store its path.

    The file-open dialog starts at the path currently shown in
    ``self.config_widget``. When the dialog yields an empty path the
    widget is left untouched.
    """
    from PyQt5.QtWidgets import QFileDialog

    current_path = self.config_widget.text()
    chosen_path, _selected_filter = QFileDialog.getOpenFileName(
        self,
        "Select Config File",
        current_path,
        "YAML Files (*.yaml *.yml)",
    )
    if not chosen_path:
        return
    self.config_widget.setText(chosen_path)
def update_repo_id_from_task(self, text):
    """Mirror the task text into the repo-id field.

    Spaces in ``text`` are replaced with underscores. Nothing happens once
    ``self.repo_id_edited`` exists and is truthy.
    """
    if getattr(self, 'repo_id_edited', False):
        return
    derived_repo_id = text.replace(" ", "_")
    self.repo_id_widget.setText(derived_repo_id)
def update_root_from_repo_id(self, text):
    """Point the root-directory field at the dataset folder named ``text``.

    The folder is always placed under the fixed datasets directory below.
    Nothing happens when ``self.root_widget`` is falsy.
    """
    if not self.root_widget:
        return
    dataset_root = f"/home/ubuntu/LYT/lerobot_aloha/datasets/{text}"
    self.root_widget.setText(dataset_root)
def browse_root_directory(self):
    """Let the user pick the dataset root directory and store its path.

    The directory dialog starts at the path currently shown in
    ``self.root_widget``. When the dialog yields an empty path the widget
    is left untouched.
    """
    from PyQt5.QtWidgets import QFileDialog

    current_root = self.root_widget.text()
    chosen_dir = QFileDialog.getExistingDirectory(
        self,
        "Select Root Directory",
        current_root,
    )
    if not chosen_dir:
        return
    self.root_widget.setText(chosen_dir)
def get_config_values(self):
    """Copy the current widget values onto ``self.cfg``.

    Also refreshes ``self.repo_id_edited``, which records whether the
    repo id now differs from ``self.original_repo_id``.
    """
    cfg = self.cfg

    # General settings.
    cfg.fps = self.fps_widget.value()
    cfg.resume = self.resume_widget.isChecked()
    repo_id = self.repo_id_widget.text()
    cfg.repo_id = repo_id
    self.repo_id_edited = repo_id != self.original_repo_id
    cfg.root = self.root_widget.text()

    # Recording settings.
    cfg.episode = self.episode_widget.value()
    cfg.num_episodes = self.num_episodes_widget.value()
    # NOTE(review): the widget is labelled "Episode Time (ms)" but the value
    # is stored in a field named `episode_time_s` -- confirm the unit that
    # the recording code expects.
    cfg.episode_time_s = self.episode_time_widget.value()
    cfg.video = self.video_widget.isChecked()
    cfg.display_cameras = self.display_cameras_widget.isChecked()
    cfg.play_sounds = self.play_sounds_widget.isChecked()
    cfg.single_task = self.single_task_widget.text()

    # Advanced settings.
    cfg.num_image_writer_processes = self.num_writer_processes_widget.value()
    cfg.num_image_writer_threads_per_camera = self.num_writer_threads_widget.value()
    cfg.use_depth_image = self.use_depth_widget.isChecked()
    cfg.use_base = self.use_base_widget.isChecked()
    cfg.push_to_hub = self.push_to_hub_widget.isChecked()

    # The "Policy" widget is not consulted here: policy is always None and
    # the control type is always "record".
    cfg.policy = None
    cfg.control_type = "record"
def start_recording(self):
    """Build the robot from the current UI settings and run a recording.

    A default configuration is fetched from ``main.get_arguments`` when
    ``self.cfg`` is missing or None. The method returns early when
    ``self.fps_widget`` is None. Any exception raised while creating the
    robot or recording is printed, and ``self.is_recording`` is reset to
    False.
    """
    if getattr(self, 'cfg', None) is None:
        from main import get_arguments
        self.cfg = get_arguments()

    if self.fps_widget is None:
        return

    # Pull the latest widget values into self.cfg before building the robot.
    self.get_config_values()

    try:
        from common.rosrobot_factory import RobotFactory
        cfg = self.cfg
        cfg.config_file = self.config_widget.text()
        self.robot = RobotFactory.create(config_file=cfg.config_file, args=cfg)

        from common.utils.data_utils import record
        # NOTE(review): the UI state below is only updated after record()
        # returns -- confirm whether record() blocks for the whole session.
        record(self.robot, self.cfg)

        self.is_recording = True
        self.record_button.setEnabled(False)
        self.stop_button.setEnabled(True)
        print("Recording started with configuration:", vars(self.cfg))
    except Exception as e:
        print(f"Failed to start recording: {e}")
        self.is_recording = False
def stop_recording(self):
    """Reset the recording state and buttons, then send a synthetic ESC key.

    The UI state is updated first; the ESC key press is emitted afterwards
    through ``pynput``.
    """
    self.is_recording = False
    self.record_button.setEnabled(True)
    self.stop_button.setEnabled(False)

    # Simulate an ESC key press (presumably what the recording loop listens
    # for in order to stop -- confirm against the recording code).
    from pynput.keyboard import Key, Controller
    virtual_keyboard = Controller()
    escape_key = Key.esc
    virtual_keyboard.press(escape_key)
    virtual_keyboard.release(escape_key)

    print("Recording stopped")
if __name__ == "__main__":
    # Create the Qt application, show the main window, and use the event
    # loop's return value as the process exit status.
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)

56
lerobot_aloha/main.py Normal file
View File

@@ -0,0 +1,56 @@
import argparse
from common.rosrobot_factory import RobotFactory
from common.utils.data_utils import record
from common.utils.replay_utils import replay
import cv2
def get_arguments():
    """
    Build the configuration namespace used by the control entry points.

    The parser declares no command line options; every field is a fixed
    default attached to the parsed namespace.

    Returns:
        argparse.Namespace: Parsed arguments
    """
    parser = argparse.ArgumentParser()
    parser.set_defaults(
        fps=30,
        resume=False,
        repo_id="move_the_bottle_from_the_right_to_the_scale_right",
        root="./data5",
        episode=0,  # replay episode
        num_image_writer_processes=0,
        num_image_writer_threads_per_camera=4,
        video=True,
        num_episodes=100,
        episode_time_s=30000,
        play_sounds=False,
        display_cameras=True,
        single_task="move the bottle from the right to the scale right",
        use_depth_image=True,
        use_base=False,
        push_to_hub=False,
        policy=None,
        control_type="record",
    )
    return parser.parse_args()
def control_robot(cfg, config_file=None):
    """
    Control robot based on configuration.

    Args:
        cfg: Configuration object. ``cfg.control_type`` selects the mode and
            must be "record" or "replay".
        config_file: Optional path to the robot YAML config. When None, the
            project's default agilex.yaml path is used.

    Raises:
        ValueError: If ``cfg.control_type`` is not a supported mode. This is
            checked before the robot is created, so an unsupported mode no
            longer builds a robot and then silently does nothing.
    """
    supported_modes = ("record", "replay")
    if cfg.control_type not in supported_modes:
        raise ValueError(
            f"Unsupported control_type {cfg.control_type!r}; "
            f"expected one of {supported_modes}"
        )

    if config_file is None:
        config_file = "/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml"

    # Create robot instance using factory pattern
    robot = RobotFactory.create(config_file=config_file, args=cfg)

    # Execute appropriate control mode
    if cfg.control_type == "record":
        record(robot, cfg)
    elif cfg.control_type == "replay":
        replay(robot, cfg)
if __name__ == "__main__":
    # Build the default configuration and run the selected control mode.
    control_robot(get_arguments())

View File

@@ -0,0 +1,78 @@
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.robot_devices.utils import busy_wait
import time
import argparse
from common.agilex_robot import AgilexRobot
import torch
import cv2
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import cycle
def get_arguments():
    """Return the fixed configuration namespace for the dataset replay script.

    The parser declares no command line options; every field is a fixed
    default attached to the parsed namespace.
    """
    parser = argparse.ArgumentParser()
    parser.set_defaults(
        fps=30,
        resume=False,
        repo_id="tangger/test",
        # Dataset roots tried previously:
        #   /home/ubuntu/LYT/lerobot_aloha/datasets/move_a_tube_on_the_scale_without_front
        #   /home/ubuntu/LYT/aloha_lerobot/data4
        root="/home/ubuntu/LYT/lerobot_aloha/datasets/abcde",
        num_image_writer_processes=0,
        num_image_writer_threads_per_camera=4,
        video=True,
        num_episodes=50,
        episode_time_s=30000,
        play_sounds=False,
        display_cameras=True,
        single_task="test test",
        use_depth_image=False,
        use_base=False,
        push_to_hub=False,
        policy=None,
        teleoprate=False,
    )
    return parser.parse_args()
# Replay recorded actions from a LeRobot dataset on the robot.
cfg = get_arguments()
robot = AgilexRobot(
    config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml",
    args=cfg,
)

inference_time_s = 360  # not referenced by the replay loop below
fps = 15
device = "cuda"  # TODO: On Mac, use "mps" or "cpu"

dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)

# NOTE: `shuffle` is not passed to the DataLoader, which uses shuffle=False.
shuffle = True
sampler = None
loader_options = dict(
    num_workers=4,
    batch_size=1,
    shuffle=False,
    sampler=sampler,
    pin_memory=device != "cpu",
    drop_last=False,
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)
dl_iter = cycle(dataloader)

# Throttle playback to `fps` frames per second.
frame_period_s = 1 / fps
for data in dl_iter:
    start_time = time.perf_counter()

    # Remove the batch dimension and move the action to the CPU, if it is
    # not there already.
    action = data["action"].squeeze(0).to("cpu")

    # Order the robot to move
    robot.send_action(action)
    print(action)

    # Wait out whatever is left of this frame's time budget.
    dt_s = time.perf_counter() - start_time
    busy_wait(frame_period_s - dt_s)

View File

@@ -1,9 +1,12 @@
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.robot_devices.utils import busy_wait
import time
import argparse
from agilex_robot import AgilexRobot
from common.agilex_robot import AgilexRobot
import torch
import cv2
def get_arguments():
parser = argparse.ArgumentParser()
@@ -29,13 +32,18 @@ def get_arguments():
cfg = get_arguments()
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
robot = AgilexRobot(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg)
inference_time_s = 360
fps = 30
device = "cuda" # TODO: On Mac, use "mps" or "cpu"
ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"
policy = ACTPolicy.from_pretrained(ckpt_path)
# ckpt_path = "/home/ubuntu/LYT/lerobot_aloha/outputs/train/act_move_bottle_on_scale_without_front/checkpoints/last/pretrained_model"
ckpt_path ="/home/ubuntu/LYT/lerobot_aloha/outputs/train/act_abcde/checkpoints/last/pretrained_model"
policy = ACTPolicy.from_pretrained(pretrained_name_or_path=ckpt_path)
# ckpt_path ="/home/ubuntu/LYT/lerobot_aloha/outputs/train/diffusion_abcde/checkpoints/020000/pretrained_model"
# policy = DiffusionPolicy.from_pretrained(pretrained_name_or_path=ckpt_path)
policy.to(device)
for _ in range(inference_time_s * fps):
@@ -46,16 +54,23 @@ for _ in range(inference_time_s * fps):
if observation is None:
print("Observation is None, skipping...")
continue
# Visualize the image in the observation
# cv2.imshow("observation", observation["observation.image"])
# Convert to pytorch format: channel first and float32 in [0,1]
# with batch dimension
for name in observation:
if "image" in name:
img = observation[name].numpy()
# cv2.imshow(name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
# cv2.imwrite(f"{name}.png", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
last_pos = observation["observation.state"]
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
@@ -65,6 +80,8 @@ for _ in range(inference_time_s * fps):
action = action.to("cpu")
# Order the robot to move
robot.send_action(action)
print("left pos:", action[:7])
print("right pos:", action[7:])
dt_s = time.perf_counter() - start_time
busy_wait(1 / fps - dt_s)

1
openpi Submodule

Submodule openpi added at 36dc3c037e

BIN
test.pt

Binary file not shown.