# File: lerobot_aloha/collect_data/collect_data_gui.py (Python, 417 lines, 18 KiB)
import os
import sys
import time
import argparse
from aloha_mobile import AlohaRobotRos
from utils import save_data, init_keyboard_listener
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
QLabel, QLineEdit, QPushButton, QCheckBox, QSpinBox,
QGroupBox, QFormLayout, QTabWidget, QTextEdit, QFileDialog,
QMessageBox, QProgressBar, QComboBox)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, pyqtSlot
from PyQt5.QtGui import QFont, QIcon, QTextCursor
class DataCollectionThread(QThread):
    """Worker thread that records ALOHA episodes without blocking the GUI.

    Signals:
        update_signal(str): log text for the GUI log view.
        progress_signal(int): collection progress as a percentage (0-100).
        finish_signal(bool, str): emitted at the end of ``run`` with a success
            flag and a short summary (also emitted on exceptions).
    """
    update_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    finish_signal = pyqtSignal(bool, str)

    # Seconds to sleep between polls of the keyboard-event dict (avoids busy-waiting).
    POLL_INTERVAL = 0.1

    # Help text shown at the start of multi-episode collection.
    KEYBOARD_HELP = (
        "\n"
        "键盘控制:\n"
        "← 左箭头: 开始录制\n"
        "→ 右箭头: 保存当前数据\n"
        "↓ 下箭头: 丢弃当前数据\n"
        "ESC: 退出程序\n"
    )

    def __init__(self, args, parent=None):
        """Store the run arguments; ROS is only initialised inside ``run``.

        Args:
            args: argparse.Namespace built by the GUI (dataset_dir, task_name,
                episode_idx, max_timesteps, num_episodes, topics, ...).
            parent: optional Qt parent object.
        """
        super(DataCollectionThread, self).__init__(parent)
        self.args = args
        # Cleared by stop(); polled by the multi-episode loops.
        self.is_running = True
        self.ros_operator = None

    def run(self):
        """Thread entry point: set up ROS and the output directory, then collect."""
        try:
            self.update_signal.emit("正在初始化ROS操作...\n")
            self.ros_operator = AlohaRobotRos(self.args)
            dataset_dir = os.path.join(self.args.dataset_dir, self.args.task_name)
            os.makedirs(dataset_dir, exist_ok=True)
            if self.args.num_episodes == 1:
                self._collect_single_episode(dataset_dir)
            else:
                self._collect_multiple_episodes(dataset_dir)
        except Exception as e:
            # Top-level boundary of the worker: report any failure to the GUI
            # instead of letting the thread die silently.
            self.update_signal.emit(f"错误: {str(e)}\n")
            self.finish_signal.emit(False, str(e))

    def _save_episode(self, dataset_dir, episode_idx, timesteps, actions):
        """Save one recorded episode if it reached max_timesteps.

        Returns:
            True if the episode was saved, False if it was too short.
        """
        if len(actions) < self.args.max_timesteps:
            self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n")
            return False
        dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
        save_data(self.args, timesteps, actions, dataset_path)
        self.update_signal.emit(f"{episode_idx} 集成功保存到 {dataset_path}.\n")
        return True

    def _collect_single_episode(self, dataset_dir):
        """Record one episode immediately (no keyboard control) and save it."""
        self.update_signal.emit(f"开始录制第 {self.args.episode_idx} 集...\n")
        timesteps, actions = self.ros_operator.process()
        if not self._save_episode(dataset_dir, self.args.episode_idx, timesteps, actions):
            self.finish_signal.emit(False, f"只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步")
            return
        # Report completion so the progress bar does not stay at 0% on success.
        self.progress_signal.emit(100)
        self.finish_signal.emit(True, "数据收集完成")

    def _wait_for_decision(self, events):
        """Poll keyboard events until the user saves, discards or exits.

        Returns:
            "save", "discard", or None when the user exits or stop() was called.
        """
        while self.is_running:
            if events["save_data"]:
                events["save_data"] = False
                return "save"
            if events["discard_data"]:
                events["discard_data"] = False
                return "discard"
            if events["exit_early"]:
                self.update_signal.emit("操作被用户终止.\n")
                self.is_running = False
                return None
            time.sleep(self.POLL_INTERVAL)
        return None

    def _collect_multiple_episodes(self, dataset_dir):
        """Keyboard-driven loop: ← records, → saves, ↓ discards, ESC exits."""
        self.update_signal.emit(self.KEYBOARD_HELP)
        listener, events = init_keyboard_listener()
        episode_idx = self.args.episode_idx
        collected_episodes = 0
        try:
            while collected_episodes < self.args.num_episodes and self.is_running:
                if events["exit_early"]:
                    self.update_signal.emit("操作被用户终止.\n")
                    break
                if events["record_start"]:
                    # Reset stale key presses before starting a new recording.
                    events["record_start"] = False
                    events["save_data"] = False
                    events["discard_data"] = False
                    self.update_signal.emit(f"\n正在录制第 {episode_idx} 集...\n")
                    timesteps, actions = self.ros_operator.process()
                    self.update_signal.emit(f"已录制 {len(actions)} 个时间步. (→ 保存, ↓ 丢弃)\n")
                    decision = self._wait_for_decision(events)
                    if decision == "save":
                        if self._save_episode(dataset_dir, episode_idx, timesteps, actions):
                            episode_idx += 1
                            collected_episodes += 1
                            progress_percentage = int(collected_episodes * 100 / self.args.num_episodes)
                            self.progress_signal.emit(progress_percentage)
                            self.update_signal.emit(f"进度: {collected_episodes}/{self.args.num_episodes} 集已收集. (← 开始新一集)\n")
                    elif decision == "discard":
                        self.update_signal.emit("数据已丢弃. 请按 ← 开始新的录制.\n")
                time.sleep(self.POLL_INTERVAL)
            if collected_episodes == self.args.num_episodes:
                self.update_signal.emit(f"\n数据收集完成! 所有 {self.args.num_episodes} 集已收集.\n")
                self.finish_signal.emit(True, "全部数据集收集完成")
            else:
                self.finish_signal.emit(False, "数据收集未完成")
        finally:
            # Always stop the keyboard listener, even if recording/saving raised.
            if listener:
                listener.stop()
                self.update_signal.emit("键盘监听器已停止\n")

    def stop(self):
        """Request the worker to finish and block until the thread exits.

        Only the multi-episode loops poll ``is_running``; a single-episode run
        (or a recording already in progress) completes before this returns.
        """
        self.is_running = False
        self.wait()
class AlohaDataCollectionGUI(QMainWindow):
    """Main window with a configuration tab and a data-collection tab.

    The form values are read once when collection starts and handed to a
    DataCollectionThread as an argparse.Namespace. Log text, progress and the
    final result come back through the thread's Qt signals.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle("ALOHA 数据收集工具")
        self.setGeometry(100, 100, 800, 700)
        # Main widget
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.main_layout = QVBoxLayout(self.central_widget)
        # Tab container
        self.tab_widget = QTabWidget()
        self.main_layout.addWidget(self.tab_widget)
        # Configuration tab
        self.config_tab = QWidget()
        self.tab_widget.addTab(self.config_tab, "配置")
        # Data-collection tab
        self.collection_tab = QWidget()
        self.tab_widget.addTab(self.collection_tab, "数据收集")
        self.setup_config_tab()
        self.setup_collection_tab()
        # Worker thread of the current run (None when idle).
        self.collection_thread = None
        # Snapshot of the args used by the current/last run, so edits made to
        # the form while collecting do not skew the episode-index update.
        self._active_args = None
        # Set when the user stops collection (or closes the window), so the
        # thread's resulting "incomplete" finish signal is not shown as an error.
        self._stop_requested = False

    def setup_config_tab(self):
        """Build the configuration form: paths, counts, topics and cameras."""
        config_layout = QVBoxLayout(self.config_tab)
        # Basic settings group
        basic_group = QGroupBox("基本配置")
        basic_form = QFormLayout()
        self.dataset_dir = QLineEdit("./data")
        self.browse_button = QPushButton("浏览")
        self.browse_button.clicked.connect(self.browse_dataset_dir)
        dir_layout = QHBoxLayout()
        dir_layout.addWidget(self.dataset_dir)
        dir_layout.addWidget(self.browse_button)
        self.task_name = QLineEdit("aloha_mobile_dummy")
        self.episode_idx = QSpinBox()
        self.episode_idx.setRange(0, 1000)
        self.max_timesteps = QSpinBox()
        self.max_timesteps.setRange(1, 10000)
        self.max_timesteps.setValue(500)
        self.num_episodes = QSpinBox()
        self.num_episodes.setRange(1, 100)
        self.num_episodes.setValue(1)
        self.frame_rate = QSpinBox()
        self.frame_rate.setRange(1, 120)
        self.frame_rate.setValue(30)
        basic_form.addRow("数据集目录:", dir_layout)
        basic_form.addRow("任务名称:", self.task_name)
        basic_form.addRow("起始集索引:", self.episode_idx)
        basic_form.addRow("最大时间步:", self.max_timesteps)
        basic_form.addRow("集数:", self.num_episodes)
        basic_form.addRow("帧率:", self.frame_rate)
        basic_group.setLayout(basic_form)
        config_layout.addWidget(basic_group)
        # Camera topic group
        camera_group = QGroupBox("相机话题")
        camera_form = QFormLayout()
        self.img_front_topic = QLineEdit('/camera_f/color/image_raw')
        self.img_left_topic = QLineEdit('/camera_l/color/image_raw')
        self.img_right_topic = QLineEdit('/camera_r/color/image_raw')
        self.img_front_depth_topic = QLineEdit('/camera_f/depth/image_raw')
        self.img_left_depth_topic = QLineEdit('/camera_l/depth/image_raw')
        self.img_right_depth_topic = QLineEdit('/camera_r/depth/image_raw')
        self.use_depth_image = QCheckBox("使用深度图像")
        camera_form.addRow("前置相机:", self.img_front_topic)
        camera_form.addRow("左腕相机:", self.img_left_topic)
        camera_form.addRow("右腕相机:", self.img_right_topic)
        camera_form.addRow("前置深度:", self.img_front_depth_topic)
        camera_form.addRow("左腕深度:", self.img_left_depth_topic)
        camera_form.addRow("右腕深度:", self.img_right_depth_topic)
        camera_form.addRow("", self.use_depth_image)
        camera_group.setLayout(camera_form)
        config_layout.addWidget(camera_group)
        # Robot topic group
        robot_group = QGroupBox("机器人话题")
        robot_form = QFormLayout()
        self.master_arm_left_topic = QLineEdit('/master/joint_left')
        self.master_arm_right_topic = QLineEdit('/master/joint_right')
        self.puppet_arm_left_topic = QLineEdit('/puppet/joint_left')
        self.puppet_arm_right_topic = QLineEdit('/puppet/joint_right')
        self.robot_base_topic = QLineEdit('/odom')
        self.use_robot_base = QCheckBox("使用机器人底盘")
        robot_form.addRow("主左臂:", self.master_arm_left_topic)
        robot_form.addRow("主右臂:", self.master_arm_right_topic)
        robot_form.addRow("从左臂:", self.puppet_arm_left_topic)
        robot_form.addRow("从右臂:", self.puppet_arm_right_topic)
        robot_form.addRow("底盘:", self.robot_base_topic)
        robot_form.addRow("", self.use_robot_base)
        robot_group.setLayout(robot_form)
        config_layout.addWidget(robot_group)
        # Camera-name selection
        camera_names_group = QGroupBox("相机名称")
        camera_names_layout = QVBoxLayout()
        self.camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
        self.camera_checkboxes = {}
        for cam_name in self.camera_names:
            self.camera_checkboxes[cam_name] = QCheckBox(cam_name)
            self.camera_checkboxes[cam_name].setChecked(True)
            camera_names_layout.addWidget(self.camera_checkboxes[cam_name])
        camera_names_group.setLayout(camera_names_layout)
        config_layout.addWidget(camera_names_group)
        # Save-configuration button
        self.save_config_button = QPushButton("保存配置")
        self.save_config_button.clicked.connect(self.save_config)
        config_layout.addWidget(self.save_config_button)

    def setup_collection_tab(self):
        """Build the collection tab: config summary, start/stop, progress and log."""
        collection_layout = QVBoxLayout(self.collection_tab)
        # Current configuration summary
        config_group = QGroupBox("当前配置")
        self.config_text = QTextEdit()
        self.config_text.setReadOnly(True)
        config_layout = QVBoxLayout()
        config_layout.addWidget(self.config_text)
        config_group.setLayout(config_layout)
        collection_layout.addWidget(config_group)
        # Start/stop buttons
        buttons_layout = QHBoxLayout()
        self.start_button = QPushButton("开始收集")
        self.start_button.setIcon(QIcon.fromTheme("media-playback-start"))
        self.start_button.clicked.connect(self.start_collection)
        self.stop_button = QPushButton("停止")
        self.stop_button.setIcon(QIcon.fromTheme("media-playback-stop"))
        self.stop_button.setEnabled(False)
        self.stop_button.clicked.connect(self.stop_collection)
        buttons_layout.addWidget(self.start_button)
        buttons_layout.addWidget(self.stop_button)
        collection_layout.addLayout(buttons_layout)
        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setValue(0)
        collection_layout.addWidget(self.progress_bar)
        # Log output
        log_group = QGroupBox("操作日志")
        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        log_layout = QVBoxLayout()
        log_layout.addWidget(self.log_text)
        log_group.setLayout(log_layout)
        collection_layout.addWidget(log_group)

    def _selected_cameras(self):
        """Return the names of the checked camera checkboxes, in display order."""
        return [cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()]

    def browse_dataset_dir(self):
        """Let the user pick the dataset root directory."""
        directory = QFileDialog.getExistingDirectory(self, "选择数据集目录", self.dataset_dir.text())
        if directory:
            self.dataset_dir.setText(directory)

    def save_config(self):
        """Show a summary of the current form values and switch to the collection tab."""
        # Explicit yes/no labels for the boolean options.
        use_depth = "是" if self.use_depth_image.isChecked() else "否"
        use_base = "是" if self.use_robot_base.isChecked() else "否"
        config_lines = [
            f"任务名称: {self.task_name.text()}",
            f"数据集目录: {self.dataset_dir.text()}",
            f"起始集索引: {self.episode_idx.value()}",
            f"最大时间步: {self.max_timesteps.value()}",
            f"集数: {self.num_episodes.value()}",
            f"帧率: {self.frame_rate.value()}",
            f"使用深度图像: {use_depth}",
            f"使用机器人底盘: {use_base}",
            f"相机: {', '.join(self._selected_cameras())}",
        ]
        # Leading and trailing newline frame the summary text.
        self.config_text.setText("\n" + "\n".join(config_lines) + "\n")
        self.tab_widget.setCurrentIndex(1)  # switch to the collection tab
        QMessageBox.information(self, "配置已保存", "配置已更新,可以开始数据收集")

    def start_collection(self):
        """Validate the form, snapshot it into args and start the worker thread."""
        # Strip so a blank name (or stray spaces) never becomes a directory name.
        task_name = self.task_name.text().strip()
        if not task_name:
            QMessageBox.warning(self, "配置错误", "请输入有效的任务名称")
            return
        # Build the arguments expected by AlohaRobotRos / save_data.
        args = argparse.Namespace(
            dataset_dir=self.dataset_dir.text(),
            task_name=task_name,
            episode_idx=self.episode_idx.value(),
            max_timesteps=self.max_timesteps.value(),
            num_episodes=self.num_episodes.value(),
            camera_names=self._selected_cameras(),
            img_front_topic=self.img_front_topic.text(),
            img_left_topic=self.img_left_topic.text(),
            img_right_topic=self.img_right_topic.text(),
            img_front_depth_topic=self.img_front_depth_topic.text(),
            img_left_depth_topic=self.img_left_depth_topic.text(),
            img_right_depth_topic=self.img_right_depth_topic.text(),
            master_arm_left_topic=self.master_arm_left_topic.text(),
            master_arm_right_topic=self.master_arm_right_topic.text(),
            puppet_arm_left_topic=self.puppet_arm_left_topic.text(),
            puppet_arm_right_topic=self.puppet_arm_right_topic.text(),
            robot_base_topic=self.robot_base_topic.text(),
            use_robot_base=self.use_robot_base.isChecked(),
            use_depth_image=self.use_depth_image.isChecked(),
            frame_rate=self.frame_rate.value()
        )
        self._active_args = args
        self._stop_requested = False
        # Update UI state
        self.start_button.setEnabled(False)
        self.stop_button.setEnabled(True)
        self.progress_bar.setValue(0)
        self.log_text.clear()
        self.log_text.append("正在初始化数据收集...\n")
        # Create and start the worker thread
        self.collection_thread = DataCollectionThread(args)
        self.collection_thread.update_signal.connect(self.update_log)
        self.collection_thread.progress_signal.connect(self.update_progress)
        self.collection_thread.finish_signal.connect(self.collection_finished)
        self.collection_thread.start()

    def stop_collection(self):
        """Stop the running worker (blocking until it exits) and reset the buttons."""
        if self.collection_thread and self.collection_thread.isRunning():
            self.log_text.append("正在停止数据收集...\n")
            self._stop_requested = True
            self.collection_thread.stop()
        self.collection_thread = None
        self.start_button.setEnabled(True)
        self.stop_button.setEnabled(False)

    def closeEvent(self, event):
        """Stop a running collection before the window closes.

        Qt aborts the program if a QThread object is destroyed while still
        running, so the worker is stopped and waited for first.
        """
        if self.collection_thread is not None and self.collection_thread.isRunning():
            self._stop_requested = True
            self.collection_thread.stop()
        super().closeEvent(event)

    @pyqtSlot(str)
    def update_log(self, message):
        """Append a worker message to the log and keep the view scrolled to the end."""
        self.log_text.append(message)
        cursor = self.log_text.textCursor()
        cursor.movePosition(QTextCursor.End)
        self.log_text.setTextCursor(cursor)

    @pyqtSlot(int)
    def update_progress(self, value):
        """Show the worker's progress percentage."""
        self.progress_bar.setValue(value)

    @pyqtSlot(bool, str)
    def collection_finished(self, success, message):
        """Handle the worker's final result and re-enable the start button."""
        self.start_button.setEnabled(True)
        self.stop_button.setEnabled(False)
        if success:
            # Advance the start index past the episodes this run saved, using
            # the snapshot taken at start (the form may have been edited since).
            # NOTE(review): a partially completed multi-episode run saves some
            # episodes but does not advance the index; confirm whether
            # save_data overwrites existing episodes on the next run.
            if self._active_args is not None:
                self.episode_idx.setValue(self._active_args.episode_idx + self._active_args.num_episodes)
            QMessageBox.information(self, "完成", message)
        elif self._stop_requested:
            # A user-requested stop is not an error; just record it in the log.
            self.update_log(f"数据收集已停止: {message}\n")
        else:
            QMessageBox.warning(self, "出错", f"数据收集失败: {message}")
def main():
    """Create the Qt application, show the collection window and run the event loop.

    The process exits with the status code returned by the Qt event loop.
    """
    application = QApplication(sys.argv)
    main_window = AlohaDataCollectionGUI()
    main_window.show()
    exit_code = application.exec_()
    sys.exit(exit_code)
if __name__ == "__main__":
    # Launch the GUI only when executed as a script, not when imported.
    main()