add convert aloha 2 lerobot

This commit is contained in:
2025-04-20 21:50:20 +08:00
parent 722de584d2
commit 25fb9c0d33
13 changed files with 753 additions and 452 deletions

View File

@@ -1,416 +1,403 @@
import os
import sys
import time
import argparse
from aloha_mobile import AlohaRobotRos
from utils import save_data, init_keyboard_listener
import os
import yaml
from collect_data import main
from types import SimpleNamespace
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
QLabel, QLineEdit, QPushButton, QCheckBox, QSpinBox,
QGroupBox, QFormLayout, QTabWidget, QTextEdit, QFileDialog,
QMessageBox, QProgressBar, QComboBox)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, pyqtSlot
from PyQt5.QtGui import QFont, QIcon, QTextCursor
QLabel, QLineEdit, QPushButton, QSpinBox, QCheckBox,
QGroupBox, QTabWidget, QMessageBox, QFileDialog)
from PyQt5.QtCore import Qt
class DataCollectionThread(QThread):
"""处理数据收集的线程"""
update_signal = pyqtSignal(str)
progress_signal = pyqtSignal(int)
finish_signal = pyqtSignal(bool, str)
def __init__(self, args, parent=None):
super(DataCollectionThread, self).__init__(parent)
self.args = args
self.is_running = True
self.ros_operator = None
def run(self):
try:
self.update_signal.emit("正在初始化ROS操作...\n")
self.ros_operator = AlohaRobotRos(self.args)
dataset_dir = os.path.join(self.args.dataset_dir, self.args.task_name)
os.makedirs(dataset_dir, exist_ok=True)
# 单集收集模式
if self.args.num_episodes == 1:
self.update_signal.emit(f"开始录制第 {self.args.episode_idx} 集...\n")
timesteps, actions = self.ros_operator.process()
if len(actions) < self.args.max_timesteps:
self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n")
self.finish_signal.emit(False, f"只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步")
return
dataset_path = os.path.join(dataset_dir, f"episode_{self.args.episode_idx}")
save_data(self.args, timesteps, actions, dataset_path)
self.update_signal.emit(f"{self.args.episode_idx} 集成功保存到 {dataset_path}.\n")
self.finish_signal.emit(True, "数据收集完成")
# 多集收集模式
else:
self.update_signal.emit("""
键盘控制:
← 左箭头: 开始录制
→ 右箭头: 保存当前数据
↓ 下箭头: 丢弃当前数据
ESC: 退出程序
""")
# 初始化键盘监听器
listener, events = init_keyboard_listener()
episode_idx = self.args.episode_idx
collected_episodes = 0
try:
while collected_episodes < self.args.num_episodes and self.is_running:
if events["exit_early"]:
self.update_signal.emit("操作被用户终止.\n")
break
if events["record_start"]:
# 重置事件状态,开始新的录制
events["record_start"] = False
events["save_data"] = False
events["discard_data"] = False
self.update_signal.emit(f"\n正在录制第 {episode_idx} 集...\n")
timesteps, actions = self.ros_operator.process()
self.update_signal.emit(f"已录制 {len(actions)} 个时间步. (→ 保存, ↓ 丢弃)\n")
# 等待用户决定保存或丢弃
while self.is_running:
if events["save_data"]:
events["save_data"] = False
if len(actions) < self.args.max_timesteps:
self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n")
else:
dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
save_data(self.args, timesteps, actions, dataset_path)
self.update_signal.emit(f"{episode_idx} 集成功保存到 {dataset_path}.\n")
episode_idx += 1
collected_episodes += 1
progress_percentage = int(collected_episodes * 100 / self.args.num_episodes)
self.progress_signal.emit(progress_percentage)
self.update_signal.emit(f"进度: {collected_episodes}/{self.args.num_episodes} 集已收集. (← 开始新一集)\n")
break
if events["discard_data"]:
events["discard_data"] = False
self.update_signal.emit("数据已丢弃. 请按 ← 开始新的录制.\n")
break
if events["exit_early"]:
self.update_signal.emit("操作被用户终止.\n")
self.is_running = False
break
time.sleep(0.1) # 减少CPU使用率
time.sleep(0.1) # 减少CPU使用率
if collected_episodes == self.args.num_episodes:
self.update_signal.emit(f"\n数据收集完成! 所有 {self.args.num_episodes} 集已收集.\n")
self.finish_signal.emit(True, "全部数据集收集完成")
else:
self.finish_signal.emit(False, "数据收集未完成")
finally:
# 确保监听器被清理
if listener:
listener.stop()
self.update_signal.emit("键盘监听器已停止\n")
except Exception as e:
self.update_signal.emit(f"错误: {str(e)}\n")
self.finish_signal.emit(False, str(e))
def stop(self):
self.is_running = False
self.wait()
class AlohaDataCollectionGUI(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("ALOHA 数据收集工具")
self.setGeometry(100, 100, 800, 700)
self.setWindowTitle("MindRobot-V1 Data Collection")
self.setGeometry(100, 100, 800, 600)
# 主组件
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
self.main_layout = QVBoxLayout(self.central_widget)
# 创建选项卡
self.tab_widget = QTabWidget()
self.main_layout.addWidget(self.tab_widget)
self.config_path = os.path.expanduser("/home/ubuntu/LYT/lerobot_aloha/collect_data/aloha.yaml")
self.create_ui()
self.setup_connections()
self.load_default_config()
# 创建配置选项卡
self.config_tab = QWidget()
self.tab_widget.addTab(self.config_tab, "配置")
def create_ui(self):
# Create tabs
self.tabs = QTabWidget()
self.main_layout.addWidget(self.tabs)
# 创建数据收集选项卡
self.collection_tab = QWidget()
self.tab_widget.addTab(self.collection_tab, "数据收集")
# General Settings Tab
self.general_tab = QWidget()
self.tabs.addTab(self.general_tab, "General Settings")
self.create_general_tab()
self.setup_config_tab()
self.setup_collection_tab()
# Camera Settings Tab
self.camera_tab = QWidget()
self.tabs.addTab(self.camera_tab, "Camera Settings")
self.create_camera_tab()
# 初始化数据收集线程
self.collection_thread = None
# Arm Settings Tab
self.arm_tab = QWidget()
self.tabs.addTab(self.arm_tab, "Arm Settings")
self.create_arm_tab()
def setup_config_tab(self):
config_layout = QVBoxLayout(self.config_tab)
# Control Buttons
self.control_group = QGroupBox("Control")
control_layout = QHBoxLayout()
# 基本配置组
basic_group = QGroupBox("基本配置")
basic_form = QFormLayout()
self.load_config_button = QPushButton("Load Config")
self.save_config_button = QPushButton("Save Config")
self.start_button = QPushButton("Start Recording")
self.stop_button = QPushButton("Stop Recording")
self.exit_button = QPushButton("Exit")
self.dataset_dir = QLineEdit("./data")
self.browse_button = QPushButton("浏览")
self.browse_button.clicked.connect(self.browse_dataset_dir)
control_layout.addWidget(self.load_config_button)
control_layout.addWidget(self.save_config_button)
control_layout.addWidget(self.start_button)
control_layout.addWidget(self.stop_button)
control_layout.addWidget(self.exit_button)
dir_layout = QHBoxLayout()
dir_layout.addWidget(self.dataset_dir)
dir_layout.addWidget(self.browse_button)
self.control_group.setLayout(control_layout)
self.main_layout.addWidget(self.control_group)
self.task_name = QLineEdit("aloha_mobile_dummy")
self.episode_idx = QSpinBox()
self.episode_idx.setRange(0, 1000)
self.max_timesteps = QSpinBox()
self.max_timesteps.setRange(1, 10000)
self.max_timesteps.setValue(500)
self.num_episodes = QSpinBox()
self.num_episodes.setRange(1, 100)
self.num_episodes.setValue(1)
self.frame_rate = QSpinBox()
self.frame_rate.setRange(1, 120)
self.frame_rate.setValue(30)
def create_general_tab(self):
layout = QVBoxLayout(self.general_tab)
basic_form.addRow("数据集目录:", dir_layout)
basic_form.addRow("任务名称:", self.task_name)
basic_form.addRow("起始集索引:", self.episode_idx)
basic_form.addRow("最大时间步:", self.max_timesteps)
basic_form.addRow("集数:", self.num_episodes)
basic_form.addRow("帧率:", self.frame_rate)
# Config File Path
config_group = QGroupBox("Configuration File")
config_layout = QHBoxLayout()
basic_group.setLayout(basic_form)
config_layout.addWidget(basic_group)
self.config_path_edit = QLineEdit(self.config_path)
self.browse_config_button = QPushButton("Browse...")
# 相机话题组
camera_group = QGroupBox("相机话题")
camera_form = QFormLayout()
config_layout.addWidget(QLabel("Config File:"))
config_layout.addWidget(self.config_path_edit)
config_layout.addWidget(self.browse_config_button)
self.img_front_topic = QLineEdit('/camera_f/color/image_raw')
self.img_left_topic = QLineEdit('/camera_l/color/image_raw')
self.img_right_topic = QLineEdit('/camera_r/color/image_raw')
self.img_front_depth_topic = QLineEdit('/camera_f/depth/image_raw')
self.img_left_depth_topic = QLineEdit('/camera_l/depth/image_raw')
self.img_right_depth_topic = QLineEdit('/camera_r/depth/image_raw')
self.use_depth_image = QCheckBox("使用深度图像")
camera_form.addRow("前置相机:", self.img_front_topic)
camera_form.addRow("左腕相机:", self.img_left_topic)
camera_form.addRow("右腕相机:", self.img_right_topic)
camera_form.addRow("前置深度:", self.img_front_depth_topic)
camera_form.addRow("左腕深度:", self.img_left_depth_topic)
camera_form.addRow("右腕深度:", self.img_right_depth_topic)
camera_form.addRow("", self.use_depth_image)
camera_group.setLayout(camera_form)
config_layout.addWidget(camera_group)
# 机器人话题组
robot_group = QGroupBox("机器人话题")
robot_form = QFormLayout()
self.master_arm_left_topic = QLineEdit('/master/joint_left')
self.master_arm_right_topic = QLineEdit('/master/joint_right')
self.puppet_arm_left_topic = QLineEdit('/puppet/joint_left')
self.puppet_arm_right_topic = QLineEdit('/puppet/joint_right')
self.robot_base_topic = QLineEdit('/odom')
self.use_robot_base = QCheckBox("使用机器人底盘")
robot_form.addRow("主左臂:", self.master_arm_left_topic)
robot_form.addRow("主右臂:", self.master_arm_right_topic)
robot_form.addRow("从左臂:", self.puppet_arm_left_topic)
robot_form.addRow("从右臂:", self.puppet_arm_right_topic)
robot_form.addRow("底盘:", self.robot_base_topic)
robot_form.addRow("", self.use_robot_base)
robot_group.setLayout(robot_form)
config_layout.addWidget(robot_group)
# 相机名称配置
camera_names_group = QGroupBox("相机名称")
camera_names_layout = QVBoxLayout()
self.camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
self.camera_checkboxes = {}
for cam_name in self.camera_names:
self.camera_checkboxes[cam_name] = QCheckBox(cam_name)
self.camera_checkboxes[cam_name].setChecked(True)
camera_names_layout.addWidget(self.camera_checkboxes[cam_name])
camera_names_group.setLayout(camera_names_layout)
config_layout.addWidget(camera_names_group)
# 保存配置按钮
self.save_config_button = QPushButton("保存配置")
self.save_config_button.clicked.connect(self.save_config)
config_layout.addWidget(self.save_config_button)
def setup_collection_tab(self):
collection_layout = QVBoxLayout(self.collection_tab)
# 当前配置展示
config_group = QGroupBox("当前配置")
self.config_text = QTextEdit()
self.config_text.setReadOnly(True)
config_layout = QVBoxLayout()
config_layout.addWidget(self.config_text)
config_group.setLayout(config_layout)
collection_layout.addWidget(config_group)
layout.addWidget(config_group)
# 操作按钮
buttons_layout = QHBoxLayout()
# Dataset Directory
dir_group = QGroupBox("Dataset Directory")
dir_layout = QHBoxLayout()
self.start_button = QPushButton("开始收集")
self.start_button.setIcon(QIcon.fromTheme("media-playback-start"))
self.start_button.clicked.connect(self.start_collection)
self.dataset_dir_edit = QLineEdit()
self.browse_dir_button = QPushButton("Browse...")
self.stop_button = QPushButton("停止")
self.stop_button.setIcon(QIcon.fromTheme("media-playback-stop"))
self.stop_button.setEnabled(False)
self.stop_button.clicked.connect(self.stop_collection)
dir_layout.addWidget(QLabel("Dataset Directory:"))
dir_layout.addWidget(self.dataset_dir_edit)
dir_layout.addWidget(self.browse_dir_button)
buttons_layout.addWidget(self.start_button)
buttons_layout.addWidget(self.stop_button)
collection_layout.addLayout(buttons_layout)
dir_group.setLayout(dir_layout)
layout.addWidget(dir_group)
# 进度条
self.progress_bar = QProgressBar()
self.progress_bar.setValue(0)
collection_layout.addWidget(self.progress_bar)
# Task Settings
task_group = QGroupBox("Task Settings")
task_layout = QVBoxLayout()
# 日志输出
log_group = QGroupBox("操作日志")
self.log_text = QTextEdit()
self.log_text.setReadOnly(True)
log_layout = QVBoxLayout()
log_layout.addWidget(self.log_text)
log_group.setLayout(log_layout)
collection_layout.addWidget(log_group)
self.task_name_edit = QLineEdit()
self.episode_idx_spin = QSpinBox()
self.episode_idx_spin.setRange(0, 9999)
self.max_timesteps_spin = QSpinBox()
self.max_timesteps_spin.setRange(1, 10000)
self.num_episodes_spin = QSpinBox()
self.num_episodes_spin.setRange(1, 1000)
self.frame_rate_spin = QSpinBox()
self.frame_rate_spin.setRange(1, 60)
task_layout.addWidget(QLabel("Task Name:"))
task_layout.addWidget(self.task_name_edit)
task_layout.addWidget(QLabel("Episode Index:"))
task_layout.addWidget(self.episode_idx_spin)
task_layout.addWidget(QLabel("Max Timesteps:"))
task_layout.addWidget(self.max_timesteps_spin)
task_layout.addWidget(QLabel("Number of Episodes:"))
task_layout.addWidget(self.num_episodes_spin)
task_layout.addWidget(QLabel("Frame Rate:"))
task_layout.addWidget(self.frame_rate_spin)
task_group.setLayout(task_layout)
layout.addWidget(task_group)
# Options
options_group = QGroupBox("Options")
options_layout = QVBoxLayout()
self.use_robot_base_check = QCheckBox("Use Robot Base")
self.use_depth_image_check = QCheckBox("Use Depth Image")
options_layout.addWidget(self.use_robot_base_check)
options_layout.addWidget(self.use_depth_image_check)
options_group.setLayout(options_layout)
layout.addWidget(options_group)
layout.addStretch()
def create_camera_tab(self):
layout = QVBoxLayout(self.camera_tab)
# Color Image Topics
color_group = QGroupBox("Color Image Topics")
color_layout = QVBoxLayout()
self.img_front_topic_edit = QLineEdit()
self.img_left_topic_edit = QLineEdit()
self.img_right_topic_edit = QLineEdit()
color_layout.addWidget(QLabel("Front Camera Topic:"))
color_layout.addWidget(self.img_front_topic_edit)
color_layout.addWidget(QLabel("Left Camera Topic:"))
color_layout.addWidget(self.img_left_topic_edit)
color_layout.addWidget(QLabel("Right Camera Topic:"))
color_layout.addWidget(self.img_right_topic_edit)
color_group.setLayout(color_layout)
layout.addWidget(color_group)
# Depth Image Topics
depth_group = QGroupBox("Depth Image Topics")
depth_layout = QVBoxLayout()
self.img_front_depth_topic_edit = QLineEdit()
self.img_left_depth_topic_edit = QLineEdit()
self.img_right_depth_topic_edit = QLineEdit()
depth_layout.addWidget(QLabel("Front Depth Topic:"))
depth_layout.addWidget(self.img_front_depth_topic_edit)
depth_layout.addWidget(QLabel("Left Depth Topic:"))
depth_layout.addWidget(self.img_left_depth_topic_edit)
depth_layout.addWidget(QLabel("Right Depth Topic:"))
depth_layout.addWidget(self.img_right_depth_topic_edit)
depth_group.setLayout(depth_layout)
layout.addWidget(depth_group)
layout.addStretch()
def create_arm_tab(self):
layout = QVBoxLayout(self.arm_tab)
# Master Arm Topics
master_group = QGroupBox("Master Arm Topics")
master_layout = QVBoxLayout()
self.master_arm_left_topic_edit = QLineEdit()
self.master_arm_right_topic_edit = QLineEdit()
master_layout.addWidget(QLabel("Master Left Arm Topic:"))
master_layout.addWidget(self.master_arm_left_topic_edit)
master_layout.addWidget(QLabel("Master Right Arm Topic:"))
master_layout.addWidget(self.master_arm_right_topic_edit)
master_group.setLayout(master_layout)
layout.addWidget(master_group)
# Puppet Arm Topics
puppet_group = QGroupBox("Puppet Arm Topics")
puppet_layout = QVBoxLayout()
self.puppet_arm_left_topic_edit = QLineEdit()
self.puppet_arm_right_topic_edit = QLineEdit()
puppet_layout.addWidget(QLabel("Puppet Left Arm Topic:"))
puppet_layout.addWidget(self.puppet_arm_left_topic_edit)
puppet_layout.addWidget(QLabel("Puppet Right Arm Topic:"))
puppet_layout.addWidget(self.puppet_arm_right_topic_edit)
puppet_group.setLayout(puppet_layout)
layout.addWidget(puppet_group)
# Robot Base Topic
base_group = QGroupBox("Robot Base Topic")
base_layout = QVBoxLayout()
self.robot_base_topic_edit = QLineEdit()
base_layout.addWidget(QLabel("Robot Base Topic:"))
base_layout.addWidget(self.robot_base_topic_edit)
base_group.setLayout(base_layout)
layout.addWidget(base_group)
layout.addStretch()
def setup_connections(self):
self.load_config_button.clicked.connect(self.load_config)
self.save_config_button.clicked.connect(self.save_config)
self.browse_config_button.clicked.connect(self.browse_config_file)
self.browse_dir_button.clicked.connect(self.browse_dataset_dir)
self.start_button.clicked.connect(self.start_recording)
self.stop_button.clicked.connect(self.stop_recording)
self.exit_button.clicked.connect(self.close)
def load_default_config(self):
default_config = {
'dataset_dir': '/home/ubuntu/LYT/lerobot_aloha/datasets/3camera',
'task_name': 'aloha_mobile_dummy',
'episode_idx': 0,
'max_timesteps': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
'num_episodes': 50,
'img_front_topic': '/camera_f/color/image_raw',
'img_left_topic': '/camera_l/color/image_raw',
'img_right_topic': '/camera_r/color/image_raw',
'img_front_depth_topic': '/camera_f/depth/image_raw',
'img_left_depth_topic': '/camera_l/depth/image_raw',
'img_right_depth_topic': '/camera_r/depth/image_raw',
'master_arm_left_topic': '/master/joint_left',
'master_arm_right_topic': '/master/joint_right',
'puppet_arm_left_topic': '/puppet/joint_left',
'puppet_arm_right_topic': '/puppet/joint_right',
'robot_base_topic': '/odom',
'use_robot_base': False,
'use_depth_image': False,
'frame_rate': 30
}
# Update UI with default values
self.update_ui_from_config(default_config)
def update_ui_from_config(self, config):
"""Update UI elements from a config dictionary"""
self.dataset_dir_edit.setText(config.get('dataset_dir', ''))
self.task_name_edit.setText(config.get('task_name', ''))
self.episode_idx_spin.setValue(config.get('episode_idx', 0))
self.max_timesteps_spin.setValue(config.get('max_timesteps', 500))
self.num_episodes_spin.setValue(config.get('num_episodes', 1))
self.frame_rate_spin.setValue(config.get('frame_rate', 30))
self.img_front_topic_edit.setText(config.get('img_front_topic', ''))
self.img_left_topic_edit.setText(config.get('img_left_topic', ''))
self.img_right_topic_edit.setText(config.get('img_right_topic', ''))
self.img_front_depth_topic_edit.setText(config.get('img_front_depth_topic', ''))
self.img_left_depth_topic_edit.setText(config.get('img_left_depth_topic', ''))
self.img_right_depth_topic_edit.setText(config.get('img_right_depth_topic', ''))
self.master_arm_left_topic_edit.setText(config.get('master_arm_left_topic', ''))
self.master_arm_right_topic_edit.setText(config.get('master_arm_right_topic', ''))
self.puppet_arm_left_topic_edit.setText(config.get('puppet_arm_left_topic', ''))
self.puppet_arm_right_topic_edit.setText(config.get('puppet_arm_right_topic', ''))
self.robot_base_topic_edit.setText(config.get('robot_base_topic', ''))
self.use_robot_base_check.setChecked(config.get('use_robot_base', False))
self.use_depth_image_check.setChecked(config.get('use_depth_image', False))
def get_config_from_ui(self):
"""Get current UI values as a config dictionary"""
config = {
'dataset_dir': self.dataset_dir_edit.text(),
'task_name': self.task_name_edit.text(),
'episode_idx': self.episode_idx_spin.value(),
'max_timesteps': self.max_timesteps_spin.value(),
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
'num_episodes': self.num_episodes_spin.value(),
'img_front_topic': self.img_front_topic_edit.text(),
'img_left_topic': self.img_left_topic_edit.text(),
'img_right_topic': self.img_right_topic_edit.text(),
'img_front_depth_topic': self.img_front_depth_topic_edit.text(),
'img_left_depth_topic': self.img_left_depth_topic_edit.text(),
'img_right_depth_topic': self.img_right_depth_topic_edit.text(),
'master_arm_left_topic': self.master_arm_left_topic_edit.text(),
'master_arm_right_topic': self.master_arm_right_topic_edit.text(),
'puppet_arm_left_topic': self.puppet_arm_left_topic_edit.text(),
'puppet_arm_right_topic': self.puppet_arm_right_topic_edit.text(),
'robot_base_topic': self.robot_base_topic_edit.text(),
'use_robot_base': self.use_robot_base_check.isChecked(),
'use_depth_image': self.use_depth_image_check.isChecked(),
'frame_rate': self.frame_rate_spin.value()
}
return config
def browse_config_file(self):
file_path, _ = QFileDialog.getOpenFileName(
self, "Select Config File", "", "YAML Files (*.yaml *.yml)"
)
if file_path:
self.config_path_edit.setText(file_path)
self.load_config()
def browse_dataset_dir(self):
directory = QFileDialog.getExistingDirectory(self, "选择数据集目录", self.dataset_dir.text())
if directory:
self.dataset_dir.setText(directory)
def save_config(self):
# 更新配置显示
selected_cameras = [cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()]
config_text = f"""
任务名称: {self.task_name.text()}
数据集目录: {self.dataset_dir.text()}
起始集索引: {self.episode_idx.value()}
最大时间步: {self.max_timesteps.value()}
集数: {self.num_episodes.value()}
帧率: {self.frame_rate.value()}
使用深度图像: {"" if self.use_depth_image.isChecked() else ""}
使用机器人底盘: {"" if self.use_robot_base.isChecked() else ""}
相机: {', '.join(selected_cameras)}
"""
self.config_text.setText(config_text)
self.tab_widget.setCurrentIndex(1) # 切换到收集选项卡
QMessageBox.information(self, "配置已保存", "配置已更新,可以开始数据收集")
def start_collection(self):
if not self.task_name.text():
QMessageBox.warning(self, "配置错误", "请输入有效的任务名称")
dir_path = QFileDialog.getExistingDirectory(
self, "Select Dataset Directory"
)
if dir_path:
self.dataset_dir_edit.setText(dir_path)
def load_config(self):
config_path = self.config_path_edit.text()
if not os.path.exists(config_path):
QMessageBox.warning(self, "Warning", f"Config file not found: {config_path}")
return
# 构建参数
args = argparse.Namespace(
dataset_dir=self.dataset_dir.text(),
task_name=self.task_name.text(),
episode_idx=self.episode_idx.value(),
max_timesteps=self.max_timesteps.value(),
num_episodes=self.num_episodes.value(),
camera_names=[cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()],
img_front_topic=self.img_front_topic.text(),
img_left_topic=self.img_left_topic.text(),
img_right_topic=self.img_right_topic.text(),
img_front_depth_topic=self.img_front_depth_topic.text(),
img_left_depth_topic=self.img_left_depth_topic.text(),
img_right_depth_topic=self.img_right_depth_topic.text(),
master_arm_left_topic=self.master_arm_left_topic.text(),
master_arm_right_topic=self.master_arm_right_topic.text(),
puppet_arm_left_topic=self.puppet_arm_left_topic.text(),
puppet_arm_right_topic=self.puppet_arm_right_topic.text(),
robot_base_topic=self.robot_base_topic.text(),
use_robot_base=self.use_robot_base.isChecked(),
use_depth_image=self.use_depth_image.isChecked(),
frame_rate=self.frame_rate.value()
)
# 更新UI状态
self.start_button.setEnabled(False)
self.stop_button.setEnabled(True)
self.progress_bar.setValue(0)
self.log_text.clear()
self.log_text.append("正在初始化数据收集...\n")
# 创建并启动线程
self.collection_thread = DataCollectionThread(args)
self.collection_thread.update_signal.connect(self.update_log)
self.collection_thread.progress_signal.connect(self.update_progress)
self.collection_thread.finish_signal.connect(self.collection_finished)
self.collection_thread.start()
def stop_collection(self):
if self.collection_thread and self.collection_thread.isRunning():
self.log_text.append("正在停止数据收集...\n")
self.collection_thread.stop()
self.collection_thread = None
try:
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
self.update_ui_from_config(config)
self.statusBar().showMessage(f"Config loaded from {config_path}", 3000)
except Exception as e:
QMessageBox.critical(self, "Error", f"Failed to load config: {str(e)}")
self.start_button.setEnabled(True)
self.stop_button.setEnabled(False)
@pyqtSlot(str)
def update_log(self, message):
self.log_text.append(message)
# 自动滚动到底部
cursor = self.log_text.textCursor()
cursor.movePosition(QTextCursor.End)
self.log_text.setTextCursor(cursor)
@pyqtSlot(int)
def update_progress(self, value):
self.progress_bar.setValue(value)
@pyqtSlot(bool, str)
def collection_finished(self, success, message):
self.start_button.setEnabled(True)
self.stop_button.setEnabled(False)
if success:
QMessageBox.information(self, "完成", message)
else:
QMessageBox.warning(self, "出错", f"数据收集失败: {message}")
def save_config(self):
config_path = self.config_path_edit.text()
if not config_path:
QMessageBox.warning(self, "Warning", "Please specify a config file path")
return
# 更新episode_idx值
if success and self.num_episodes.value() > 0:
self.episode_idx.setValue(self.episode_idx.value() + self.num_episodes.value())
try:
config = self.get_config_from_ui()
with open(config_path, 'w') as f:
yaml.dump(config, f, default_flow_style=False)
self.statusBar().showMessage(f"Config saved to {config_path}", 3000)
except Exception as e:
QMessageBox.critical(self, "Error", f"Failed to save config: {str(e)}")
def start_recording(self):
try:
# Save current config to a temporary file
temp_config_path = "/tmp/aloha_temp_config.yaml"
config = self.get_config_from_ui()
# Validate inputs
if not config['dataset_dir']:
QMessageBox.warning(self, "Warning", "Dataset directory cannot be empty!")
return
if not config['task_name']:
QMessageBox.warning(self, "Warning", "Task name cannot be empty!")
return
with open(temp_config_path, 'w') as f:
yaml.dump(config, f, default_flow_style=False)
self.statusBar().showMessage("Recording started...")
# Start recording with the temporary config file
exit_code = main(temp_config_path)
if exit_code == 0:
self.statusBar().showMessage("Recording completed successfully!", 5000)
else:
self.statusBar().showMessage("Recording completed with errors", 5000)
except Exception as e:
QMessageBox.critical(self, "Error", f"An error occurred: {str(e)}")
self.statusBar().showMessage("Recording failed", 5000)
def stop_recording(self):
# In a real application, this would signal the recording thread to stop
self.statusBar().showMessage("Recording stopped", 5000)
QMessageBox.information(self, "Info", "Stop recording requested. This would stop the recording in a real implementation.")
def main():
if __name__ == "__main__":
app = QApplication(sys.argv)
window = AlohaDataCollectionGUI()
window.show()
sys.exit(app.exec_())
if __name__ == "__main__":
main()