417 lines
18 KiB
Python
417 lines
18 KiB
Python
import os
|
|
import sys
|
|
import time
|
|
import argparse
|
|
from aloha_mobile import AlohaRobotRos
|
|
from utils import save_data, init_keyboard_listener
|
|
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
|
|
QLabel, QLineEdit, QPushButton, QCheckBox, QSpinBox,
|
|
QGroupBox, QFormLayout, QTabWidget, QTextEdit, QFileDialog,
|
|
QMessageBox, QProgressBar, QComboBox)
|
|
from PyQt5.QtCore import Qt, QThread, pyqtSignal, pyqtSlot
|
|
from PyQt5.QtGui import QFont, QIcon, QTextCursor
|
|
|
|
class DataCollectionThread(QThread):
    """Background thread that records robot episodes and saves them to disk.

    Communication with the GUI happens exclusively through three signals:

    * ``update_signal(str)``    -- log text to append to the UI log.
    * ``progress_signal(int)``  -- completion percentage (0-100).
    * ``finish_signal(bool, str)`` -- emitted once when the run ends, with a
      success flag and a short summary message.

    When ``args.num_episodes == 1`` a single episode is recorded immediately.
    Otherwise recording is driven by the keyboard events returned from
    ``init_keyboard_listener()`` (start / save / discard / exit).
    """

    update_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    finish_signal = pyqtSignal(bool, str)

    # Help text shown at the start of a multi-episode session.
    # NOTE(review): the text this was restored from had lost its indentation,
    # so any leading spaces the original literal carried on each line are not
    # reproduced here -- confirm against the upstream file if that matters.
    _KEYBOARD_HELP = (
        "\n"
        "键盘控制:\n"
        "← 左箭头: 开始录制\n"
        "→ 右箭头: 保存当前数据\n"
        "↓ 下箭头: 丢弃当前数据\n"
        "ESC: 退出程序\n"
    )

    # Seconds to sleep between polls of the keyboard event flags.
    _POLL_INTERVAL = 0.1

    def __init__(self, args, parent=None):
        """Store the run configuration.

        Args:
            args: Namespace-like object; this class reads ``dataset_dir``,
                ``task_name``, ``episode_idx``, ``max_timesteps`` and
                ``num_episodes`` from it and forwards it unchanged to
                ``AlohaRobotRos`` and ``save_data``.
            parent: Optional Qt parent object.
        """
        super(DataCollectionThread, self).__init__(parent)
        self.args = args
        self.is_running = True
        self.ros_operator = None

    def run(self):
        """Thread entry point: set up the robot interface and collect data.

        Any exception is reported through ``update_signal`` and
        ``finish_signal(False, ...)`` instead of propagating out of the thread.
        """
        try:
            self.update_signal.emit("正在初始化ROS操作...\n")
            self.ros_operator = AlohaRobotRos(self.args)
            dataset_dir = os.path.join(self.args.dataset_dir, self.args.task_name)
            os.makedirs(dataset_dir, exist_ok=True)

            if self.args.num_episodes == 1:
                self._collect_single(dataset_dir)
            else:
                self._collect_multiple(dataset_dir)
        except Exception as e:
            # Top-level boundary of the worker thread: report and finish.
            self.update_signal.emit(f"错误: {str(e)}\n")
            self.finish_signal.emit(False, str(e))

    def _save_episode(self, dataset_dir, episode_idx, timesteps, actions):
        """Save one recording if it reached ``max_timesteps``.

        Returns:
            True when the episode was written, False when it was too short
            (a failure line is logged in that case and nothing is written).
        """
        recorded = len(actions)
        required = self.args.max_timesteps
        if recorded < required:
            self.update_signal.emit(f"保存失败: 只录制了 {recorded}/{required} 个时间步.\n")
            return False

        dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
        save_data(self.args, timesteps, actions, dataset_path)
        self.update_signal.emit(f"第 {episode_idx} 集成功保存到 {dataset_path}.\n")
        return True

    def _collect_single(self, dataset_dir):
        """Record exactly one episode and report the outcome."""
        episode_idx = self.args.episode_idx
        self.update_signal.emit(f"开始录制第 {episode_idx} 集...\n")
        timesteps, actions = self.ros_operator.process()

        if not self._save_episode(dataset_dir, episode_idx, timesteps, actions):
            self.finish_signal.emit(
                False, f"只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步"
            )
            return

        # The only episode is done, so the progress bar should show completion.
        self.progress_signal.emit(100)
        self.finish_signal.emit(True, "数据收集完成")

    def _await_decision(self, events):
        """Block until the user decides what to do with the last recording.

        Save takes precedence over discard, which takes precedence over exit,
        when several flags are set at the same poll.

        Returns:
            True if the user asked to save; False if the recording was
            discarded, the user asked to exit, or the thread was stopped.
        """
        while self.is_running:
            if events["save_data"]:
                events["save_data"] = False
                return True

            if events["discard_data"]:
                events["discard_data"] = False
                self.update_signal.emit("数据已丢弃. 请按 ← 开始新的录制.\n")
                return False

            if events["exit_early"]:
                self.update_signal.emit("操作被用户终止.\n")
                # Clearing the flag also terminates the outer collection loop.
                self.is_running = False
                return False

            time.sleep(self._POLL_INTERVAL)  # avoid busy-waiting
        return False

    def _collect_multiple(self, dataset_dir):
        """Record several episodes under keyboard control."""
        self.update_signal.emit(self._KEYBOARD_HELP)

        listener, events = init_keyboard_listener()
        episode_idx = self.args.episode_idx
        target = self.args.num_episodes
        collected = 0

        try:
            while collected < target and self.is_running:
                if events["exit_early"]:
                    self.update_signal.emit("操作被用户终止.\n")
                    break

                if events["record_start"]:
                    # Clear stale key presses before starting a new recording.
                    events["record_start"] = False
                    events["save_data"] = False
                    events["discard_data"] = False

                    self.update_signal.emit(f"\n正在录制第 {episode_idx} 集...\n")
                    timesteps, actions = self.ros_operator.process()
                    self.update_signal.emit(f"已录制 {len(actions)} 个时间步. (→ 保存, ↓ 丢弃)\n")

                    wants_save = self._await_decision(events)
                    if wants_save and self._save_episode(
                        dataset_dir, episode_idx, timesteps, actions
                    ):
                        episode_idx += 1
                        collected += 1
                        self.progress_signal.emit(int(collected * 100 / target))
                        self.update_signal.emit(
                            f"进度: {collected}/{target} 集已收集. (← 开始新一集)\n"
                        )

                time.sleep(self._POLL_INTERVAL)  # avoid busy-waiting

            if collected == target:
                self.update_signal.emit(f"\n数据收集完成! 所有 {target} 集已收集.\n")
                self.finish_signal.emit(True, "全部数据集收集完成")
            else:
                self.finish_signal.emit(False, "数据收集未完成")
        finally:
            # Always release the keyboard listener, even on error.
            if listener:
                listener.stop()
                self.update_signal.emit("键盘监听器已停止\n")

    def stop(self):
        """Request the collection loops to end and block until the thread exits.

        The flag is only polled between recordings, so a call to
        ``ros_operator.process()`` that is already in progress is not
        interrupted by this method.
        """
        self.is_running = False
        self.wait()
|
|
|
|
class AlohaDataCollectionGUI(QMainWindow):
    """Main window with a configuration tab and a data-collection tab.

    The configuration tab holds the dataset, camera-topic and robot-topic
    settings. The collection tab starts/stops a ``DataCollectionThread`` and
    shows its log output and progress.
    """

    def __init__(self):
        """Build the window, both tabs, and initialise run state."""
        super().__init__()
        self.setWindowTitle("ALOHA 数据收集工具")
        self.setGeometry(100, 100, 800, 700)

        # Central widget and top-level layout.
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.main_layout = QVBoxLayout(self.central_widget)

        # Tab container.
        self.tab_widget = QTabWidget()
        self.main_layout.addWidget(self.tab_widget)

        # Configuration tab.
        self.config_tab = QWidget()
        self.tab_widget.addTab(self.config_tab, "配置")

        # Data-collection tab.
        self.collection_tab = QWidget()
        self.tab_widget.addTab(self.collection_tab, "数据收集")

        self.setup_config_tab()
        self.setup_collection_tab()

        # Worker thread of the current run (None while idle).
        self.collection_thread = None
        # Arguments the current/last run was started with; used to advance
        # the start index independently of later edits to the spin boxes.
        self._active_args = None

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _make_spinbox(minimum, maximum, value=None):
        """Create a QSpinBox with the given range and optional initial value."""
        box = QSpinBox()
        box.setRange(minimum, maximum)
        if value is not None:
            box.setValue(value)
        return box

    def _selected_cameras(self):
        """Return the names of the cameras whose check box is ticked."""
        return [cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()]

    # ------------------------------------------------------------------
    # Configuration tab
    # ------------------------------------------------------------------
    def setup_config_tab(self):
        """Populate the configuration tab with its setting groups."""
        config_layout = QVBoxLayout(self.config_tab)
        config_layout.addWidget(self._build_basic_group())
        config_layout.addWidget(self._build_camera_topic_group())
        config_layout.addWidget(self._build_robot_topic_group())
        config_layout.addWidget(self._build_camera_names_group())

        # "Save configuration" button.
        self.save_config_button = QPushButton("保存配置")
        self.save_config_button.clicked.connect(self.save_config)
        config_layout.addWidget(self.save_config_button)

    def _build_basic_group(self):
        """Create the group with dataset location and episode settings."""
        basic_group = QGroupBox("基本配置")
        basic_form = QFormLayout()

        self.dataset_dir = QLineEdit("./data")
        self.browse_button = QPushButton("浏览")
        self.browse_button.clicked.connect(self.browse_dataset_dir)

        dir_layout = QHBoxLayout()
        dir_layout.addWidget(self.dataset_dir)
        dir_layout.addWidget(self.browse_button)

        self.task_name = QLineEdit("aloha_mobile_dummy")
        self.episode_idx = self._make_spinbox(0, 1000)
        self.max_timesteps = self._make_spinbox(1, 10000, 500)
        self.num_episodes = self._make_spinbox(1, 100, 1)
        self.frame_rate = self._make_spinbox(1, 120, 30)

        basic_form.addRow("数据集目录:", dir_layout)
        basic_form.addRow("任务名称:", self.task_name)
        basic_form.addRow("起始集索引:", self.episode_idx)
        basic_form.addRow("最大时间步:", self.max_timesteps)
        basic_form.addRow("集数:", self.num_episodes)
        basic_form.addRow("帧率:", self.frame_rate)

        basic_group.setLayout(basic_form)
        return basic_group

    def _build_camera_topic_group(self):
        """Create the group with colour/depth camera topic names."""
        camera_group = QGroupBox("相机话题")
        camera_form = QFormLayout()

        self.img_front_topic = QLineEdit('/camera_f/color/image_raw')
        self.img_left_topic = QLineEdit('/camera_l/color/image_raw')
        self.img_right_topic = QLineEdit('/camera_r/color/image_raw')

        self.img_front_depth_topic = QLineEdit('/camera_f/depth/image_raw')
        self.img_left_depth_topic = QLineEdit('/camera_l/depth/image_raw')
        self.img_right_depth_topic = QLineEdit('/camera_r/depth/image_raw')

        self.use_depth_image = QCheckBox("使用深度图像")

        camera_form.addRow("前置相机:", self.img_front_topic)
        camera_form.addRow("左腕相机:", self.img_left_topic)
        camera_form.addRow("右腕相机:", self.img_right_topic)
        camera_form.addRow("前置深度:", self.img_front_depth_topic)
        camera_form.addRow("左腕深度:", self.img_left_depth_topic)
        camera_form.addRow("右腕深度:", self.img_right_depth_topic)
        camera_form.addRow("", self.use_depth_image)

        camera_group.setLayout(camera_form)
        return camera_group

    def _build_robot_topic_group(self):
        """Create the group with arm and base topic names."""
        robot_group = QGroupBox("机器人话题")
        robot_form = QFormLayout()

        self.master_arm_left_topic = QLineEdit('/master/joint_left')
        self.master_arm_right_topic = QLineEdit('/master/joint_right')
        self.puppet_arm_left_topic = QLineEdit('/puppet/joint_left')
        self.puppet_arm_right_topic = QLineEdit('/puppet/joint_right')
        self.robot_base_topic = QLineEdit('/odom')
        self.use_robot_base = QCheckBox("使用机器人底盘")

        robot_form.addRow("主左臂:", self.master_arm_left_topic)
        robot_form.addRow("主右臂:", self.master_arm_right_topic)
        robot_form.addRow("从左臂:", self.puppet_arm_left_topic)
        robot_form.addRow("从右臂:", self.puppet_arm_right_topic)
        robot_form.addRow("底盘:", self.robot_base_topic)
        robot_form.addRow("", self.use_robot_base)

        robot_group.setLayout(robot_form)
        return robot_group

    def _build_camera_names_group(self):
        """Create one (initially ticked) check box per known camera name."""
        camera_names_group = QGroupBox("相机名称")
        camera_names_layout = QVBoxLayout()

        self.camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
        self.camera_checkboxes = {}

        for cam_name in self.camera_names:
            checkbox = QCheckBox(cam_name)
            checkbox.setChecked(True)
            self.camera_checkboxes[cam_name] = checkbox
            camera_names_layout.addWidget(checkbox)

        camera_names_group.setLayout(camera_names_layout)
        return camera_names_group

    # ------------------------------------------------------------------
    # Collection tab
    # ------------------------------------------------------------------
    def setup_collection_tab(self):
        """Populate the collection tab: config summary, buttons, progress, log."""
        collection_layout = QVBoxLayout(self.collection_tab)

        # Read-only summary of the current configuration.
        config_group = QGroupBox("当前配置")
        self.config_text = QTextEdit()
        self.config_text.setReadOnly(True)
        config_layout = QVBoxLayout()
        config_layout.addWidget(self.config_text)
        config_group.setLayout(config_layout)
        collection_layout.addWidget(config_group)

        # Start / stop buttons.
        buttons_layout = QHBoxLayout()

        self.start_button = QPushButton("开始收集")
        self.start_button.setIcon(QIcon.fromTheme("media-playback-start"))
        self.start_button.clicked.connect(self.start_collection)

        self.stop_button = QPushButton("停止")
        self.stop_button.setIcon(QIcon.fromTheme("media-playback-stop"))
        self.stop_button.setEnabled(False)
        self.stop_button.clicked.connect(self.stop_collection)

        buttons_layout.addWidget(self.start_button)
        buttons_layout.addWidget(self.stop_button)
        collection_layout.addLayout(buttons_layout)

        # Progress bar.
        self.progress_bar = QProgressBar()
        self.progress_bar.setValue(0)
        collection_layout.addWidget(self.progress_bar)

        # Log output.
        log_group = QGroupBox("操作日志")
        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        log_layout = QVBoxLayout()
        log_layout.addWidget(self.log_text)
        log_group.setLayout(log_layout)
        collection_layout.addWidget(log_group)

    # ------------------------------------------------------------------
    # Actions
    # ------------------------------------------------------------------
    def browse_dataset_dir(self):
        """Let the user pick the dataset directory via a folder dialog."""
        directory = QFileDialog.getExistingDirectory(self, "选择数据集目录", self.dataset_dir.text())
        if directory:  # empty string means the dialog was cancelled
            self.dataset_dir.setText(directory)

    def save_config(self):
        """Show a summary of the current settings and switch to the collection tab."""
        selected_cameras = self._selected_cameras()
        depth_flag = "是" if self.use_depth_image.isChecked() else "否"
        base_flag = "是" if self.use_robot_base.isChecked() else "否"

        # NOTE(review): the text this was restored from had lost its
        # indentation, so any leading spaces the original multi-line literal
        # carried on each line are not reproduced here -- confirm upstream.
        summary_lines = [
            f"任务名称: {self.task_name.text()}",
            f"数据集目录: {self.dataset_dir.text()}",
            f"起始集索引: {self.episode_idx.value()}",
            f"最大时间步: {self.max_timesteps.value()}",
            f"集数: {self.num_episodes.value()}",
            f"帧率: {self.frame_rate.value()}",
            f"使用深度图像: {depth_flag}",
            f"使用机器人底盘: {base_flag}",
            f"相机: {', '.join(selected_cameras)}",
        ]
        config_text = "\n" + "\n".join(summary_lines) + "\n"

        self.config_text.setText(config_text)
        self.tab_widget.setCurrentIndex(1)  # switch to the collection tab

        QMessageBox.information(self, "配置已保存", "配置已更新,可以开始数据收集")

    def start_collection(self):
        """Validate the form, build the argument namespace and start the worker."""
        # Reject empty and whitespace-only names: the name becomes a directory.
        task_name = self.task_name.text().strip()
        if not task_name:
            QMessageBox.warning(self, "配置错误", "请输入有效的任务名称")
            return

        args = argparse.Namespace(
            dataset_dir=self.dataset_dir.text(),
            task_name=task_name,
            episode_idx=self.episode_idx.value(),
            max_timesteps=self.max_timesteps.value(),
            num_episodes=self.num_episodes.value(),
            camera_names=self._selected_cameras(),
            img_front_topic=self.img_front_topic.text(),
            img_left_topic=self.img_left_topic.text(),
            img_right_topic=self.img_right_topic.text(),
            img_front_depth_topic=self.img_front_depth_topic.text(),
            img_left_depth_topic=self.img_left_depth_topic.text(),
            img_right_depth_topic=self.img_right_depth_topic.text(),
            master_arm_left_topic=self.master_arm_left_topic.text(),
            master_arm_right_topic=self.master_arm_right_topic.text(),
            puppet_arm_left_topic=self.puppet_arm_left_topic.text(),
            puppet_arm_right_topic=self.puppet_arm_right_topic.text(),
            robot_base_topic=self.robot_base_topic.text(),
            use_robot_base=self.use_robot_base.isChecked(),
            use_depth_image=self.use_depth_image.isChecked(),
            frame_rate=self.frame_rate.value(),
        )
        # Remember what this run was started with (see collection_finished).
        self._active_args = args

        # Reflect the running state in the UI.
        self.start_button.setEnabled(False)
        self.stop_button.setEnabled(True)
        self.progress_bar.setValue(0)
        self.log_text.clear()
        self.log_text.append("正在初始化数据收集...\n")

        # Create, wire up and start the worker thread.
        self.collection_thread = DataCollectionThread(args)
        self.collection_thread.update_signal.connect(self.update_log)
        self.collection_thread.progress_signal.connect(self.update_progress)
        self.collection_thread.finish_signal.connect(self.collection_finished)
        self.collection_thread.start()

    def stop_collection(self):
        """Stop the running worker (blocks until it exits) and reset the buttons."""
        if self.collection_thread and self.collection_thread.isRunning():
            self.log_text.append("正在停止数据收集...\n")
            self.collection_thread.stop()
            self.collection_thread = None

        self.start_button.setEnabled(True)
        self.stop_button.setEnabled(False)

    def closeEvent(self, event):
        """Stop the worker before the window closes.

        Without this, closing the window during a run lets the QThread object
        be destroyed while its thread is still executing.
        """
        thread = self.collection_thread
        if thread is not None and thread.isRunning():
            # The window is going away; no result dialog should pop up.
            thread.finish_signal.disconnect(self.collection_finished)
            thread.stop()
        self.collection_thread = None
        event.accept()

    # ------------------------------------------------------------------
    # Slots connected to the worker thread
    # ------------------------------------------------------------------
    @pyqtSlot(str)
    def update_log(self, message):
        """Append a message to the log view and scroll to the end."""
        self.log_text.append(message)
        cursor = self.log_text.textCursor()
        cursor.movePosition(QTextCursor.End)
        self.log_text.setTextCursor(cursor)

    @pyqtSlot(int)
    def update_progress(self, value):
        """Set the progress bar to the given percentage."""
        self.progress_bar.setValue(value)

    @pyqtSlot(bool, str)
    def collection_finished(self, success, message):
        """Reset the buttons, report the outcome and advance the start index."""
        self.start_button.setEnabled(True)
        self.stop_button.setEnabled(False)

        if success:
            QMessageBox.information(self, "完成", message)
        else:
            QMessageBox.warning(self, "出错", f"数据收集失败: {message}")

        # After a successful run, point the start index past the episodes just
        # written. Use the values the run was started with: the spin boxes may
        # have been edited while the worker was running.
        run_args = self._active_args
        if success and run_args is not None:
            self.episode_idx.setValue(run_args.episode_idx + run_args.num_episodes)
|
|
|
|
def main():
    """Create the Qt application, show the main window and run the event loop.

    The process exits with the status code returned by the event loop.
    """
    qt_app = QApplication(sys.argv)
    main_window = AlohaDataCollectionGUI()
    main_window.show()
    exit_code = qt_app.exec_()
    sys.exit(exit_code)


# Launch the GUI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|