add convert aloha 2 lerobot
This commit is contained in:
@@ -1,416 +1,403 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import argparse
|
||||
from aloha_mobile import AlohaRobotRos
|
||||
from utils import save_data, init_keyboard_listener
|
||||
import os
|
||||
import yaml
|
||||
from collect_data import main
|
||||
from types import SimpleNamespace
|
||||
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
|
||||
QLabel, QLineEdit, QPushButton, QCheckBox, QSpinBox,
|
||||
QGroupBox, QFormLayout, QTabWidget, QTextEdit, QFileDialog,
|
||||
QMessageBox, QProgressBar, QComboBox)
|
||||
from PyQt5.QtCore import Qt, QThread, pyqtSignal, pyqtSlot
|
||||
from PyQt5.QtGui import QFont, QIcon, QTextCursor
|
||||
QLabel, QLineEdit, QPushButton, QSpinBox, QCheckBox,
|
||||
QGroupBox, QTabWidget, QMessageBox, QFileDialog)
|
||||
from PyQt5.QtCore import Qt
|
||||
|
||||
|
||||
class DataCollectionThread(QThread):
|
||||
"""处理数据收集的线程"""
|
||||
update_signal = pyqtSignal(str)
|
||||
progress_signal = pyqtSignal(int)
|
||||
finish_signal = pyqtSignal(bool, str)
|
||||
|
||||
def __init__(self, args, parent=None):
|
||||
super(DataCollectionThread, self).__init__(parent)
|
||||
self.args = args
|
||||
self.is_running = True
|
||||
self.ros_operator = None
|
||||
|
||||
def run(self):
|
||||
try:
|
||||
self.update_signal.emit("正在初始化ROS操作...\n")
|
||||
self.ros_operator = AlohaRobotRos(self.args)
|
||||
dataset_dir = os.path.join(self.args.dataset_dir, self.args.task_name)
|
||||
os.makedirs(dataset_dir, exist_ok=True)
|
||||
|
||||
# 单集收集模式
|
||||
if self.args.num_episodes == 1:
|
||||
self.update_signal.emit(f"开始录制第 {self.args.episode_idx} 集...\n")
|
||||
timesteps, actions = self.ros_operator.process()
|
||||
|
||||
if len(actions) < self.args.max_timesteps:
|
||||
self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n")
|
||||
self.finish_signal.emit(False, f"只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步")
|
||||
return
|
||||
|
||||
dataset_path = os.path.join(dataset_dir, f"episode_{self.args.episode_idx}")
|
||||
save_data(self.args, timesteps, actions, dataset_path)
|
||||
self.update_signal.emit(f"第 {self.args.episode_idx} 集成功保存到 {dataset_path}.\n")
|
||||
self.finish_signal.emit(True, "数据收集完成")
|
||||
|
||||
# 多集收集模式
|
||||
else:
|
||||
self.update_signal.emit("""
|
||||
键盘控制:
|
||||
← 左箭头: 开始录制
|
||||
→ 右箭头: 保存当前数据
|
||||
↓ 下箭头: 丢弃当前数据
|
||||
ESC: 退出程序
|
||||
""")
|
||||
# 初始化键盘监听器
|
||||
listener, events = init_keyboard_listener()
|
||||
episode_idx = self.args.episode_idx
|
||||
collected_episodes = 0
|
||||
|
||||
try:
|
||||
while collected_episodes < self.args.num_episodes and self.is_running:
|
||||
if events["exit_early"]:
|
||||
self.update_signal.emit("操作被用户终止.\n")
|
||||
break
|
||||
|
||||
if events["record_start"]:
|
||||
# 重置事件状态,开始新的录制
|
||||
events["record_start"] = False
|
||||
events["save_data"] = False
|
||||
events["discard_data"] = False
|
||||
|
||||
self.update_signal.emit(f"\n正在录制第 {episode_idx} 集...\n")
|
||||
timesteps, actions = self.ros_operator.process()
|
||||
self.update_signal.emit(f"已录制 {len(actions)} 个时间步. (→ 保存, ↓ 丢弃)\n")
|
||||
|
||||
# 等待用户决定保存或丢弃
|
||||
while self.is_running:
|
||||
if events["save_data"]:
|
||||
events["save_data"] = False
|
||||
|
||||
if len(actions) < self.args.max_timesteps:
|
||||
self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n")
|
||||
else:
|
||||
dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
|
||||
save_data(self.args, timesteps, actions, dataset_path)
|
||||
self.update_signal.emit(f"第 {episode_idx} 集成功保存到 {dataset_path}.\n")
|
||||
episode_idx += 1
|
||||
collected_episodes += 1
|
||||
progress_percentage = int(collected_episodes * 100 / self.args.num_episodes)
|
||||
self.progress_signal.emit(progress_percentage)
|
||||
self.update_signal.emit(f"进度: {collected_episodes}/{self.args.num_episodes} 集已收集. (← 开始新一集)\n")
|
||||
break
|
||||
|
||||
if events["discard_data"]:
|
||||
events["discard_data"] = False
|
||||
self.update_signal.emit("数据已丢弃. 请按 ← 开始新的录制.\n")
|
||||
break
|
||||
|
||||
if events["exit_early"]:
|
||||
self.update_signal.emit("操作被用户终止.\n")
|
||||
self.is_running = False
|
||||
break
|
||||
|
||||
time.sleep(0.1) # 减少CPU使用率
|
||||
|
||||
time.sleep(0.1) # 减少CPU使用率
|
||||
|
||||
if collected_episodes == self.args.num_episodes:
|
||||
self.update_signal.emit(f"\n数据收集完成! 所有 {self.args.num_episodes} 集已收集.\n")
|
||||
self.finish_signal.emit(True, "全部数据集收集完成")
|
||||
else:
|
||||
self.finish_signal.emit(False, "数据收集未完成")
|
||||
|
||||
finally:
|
||||
# 确保监听器被清理
|
||||
if listener:
|
||||
listener.stop()
|
||||
self.update_signal.emit("键盘监听器已停止\n")
|
||||
|
||||
except Exception as e:
|
||||
self.update_signal.emit(f"错误: {str(e)}\n")
|
||||
self.finish_signal.emit(False, str(e))
|
||||
|
||||
def stop(self):
|
||||
self.is_running = False
|
||||
self.wait()
|
||||
|
||||
class AlohaDataCollectionGUI(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.setWindowTitle("ALOHA 数据收集工具")
|
||||
self.setGeometry(100, 100, 800, 700)
|
||||
self.setWindowTitle("MindRobot-V1 Data Collection")
|
||||
self.setGeometry(100, 100, 800, 600)
|
||||
|
||||
# 主组件
|
||||
self.central_widget = QWidget()
|
||||
self.setCentralWidget(self.central_widget)
|
||||
|
||||
self.main_layout = QVBoxLayout(self.central_widget)
|
||||
|
||||
# 创建选项卡
|
||||
self.tab_widget = QTabWidget()
|
||||
self.main_layout.addWidget(self.tab_widget)
|
||||
self.config_path = os.path.expanduser("/home/ubuntu/LYT/lerobot_aloha/collect_data/aloha.yaml")
|
||||
self.create_ui()
|
||||
self.setup_connections()
|
||||
self.load_default_config()
|
||||
|
||||
# 创建配置选项卡
|
||||
self.config_tab = QWidget()
|
||||
self.tab_widget.addTab(self.config_tab, "配置")
|
||||
def create_ui(self):
|
||||
# Create tabs
|
||||
self.tabs = QTabWidget()
|
||||
self.main_layout.addWidget(self.tabs)
|
||||
|
||||
# 创建数据收集选项卡
|
||||
self.collection_tab = QWidget()
|
||||
self.tab_widget.addTab(self.collection_tab, "数据收集")
|
||||
# General Settings Tab
|
||||
self.general_tab = QWidget()
|
||||
self.tabs.addTab(self.general_tab, "General Settings")
|
||||
self.create_general_tab()
|
||||
|
||||
self.setup_config_tab()
|
||||
self.setup_collection_tab()
|
||||
# Camera Settings Tab
|
||||
self.camera_tab = QWidget()
|
||||
self.tabs.addTab(self.camera_tab, "Camera Settings")
|
||||
self.create_camera_tab()
|
||||
|
||||
# 初始化数据收集线程
|
||||
self.collection_thread = None
|
||||
# Arm Settings Tab
|
||||
self.arm_tab = QWidget()
|
||||
self.tabs.addTab(self.arm_tab, "Arm Settings")
|
||||
self.create_arm_tab()
|
||||
|
||||
def setup_config_tab(self):
|
||||
config_layout = QVBoxLayout(self.config_tab)
|
||||
# Control Buttons
|
||||
self.control_group = QGroupBox("Control")
|
||||
control_layout = QHBoxLayout()
|
||||
|
||||
# 基本配置组
|
||||
basic_group = QGroupBox("基本配置")
|
||||
basic_form = QFormLayout()
|
||||
self.load_config_button = QPushButton("Load Config")
|
||||
self.save_config_button = QPushButton("Save Config")
|
||||
self.start_button = QPushButton("Start Recording")
|
||||
self.stop_button = QPushButton("Stop Recording")
|
||||
self.exit_button = QPushButton("Exit")
|
||||
|
||||
self.dataset_dir = QLineEdit("./data")
|
||||
self.browse_button = QPushButton("浏览")
|
||||
self.browse_button.clicked.connect(self.browse_dataset_dir)
|
||||
control_layout.addWidget(self.load_config_button)
|
||||
control_layout.addWidget(self.save_config_button)
|
||||
control_layout.addWidget(self.start_button)
|
||||
control_layout.addWidget(self.stop_button)
|
||||
control_layout.addWidget(self.exit_button)
|
||||
|
||||
dir_layout = QHBoxLayout()
|
||||
dir_layout.addWidget(self.dataset_dir)
|
||||
dir_layout.addWidget(self.browse_button)
|
||||
self.control_group.setLayout(control_layout)
|
||||
self.main_layout.addWidget(self.control_group)
|
||||
|
||||
self.task_name = QLineEdit("aloha_mobile_dummy")
|
||||
self.episode_idx = QSpinBox()
|
||||
self.episode_idx.setRange(0, 1000)
|
||||
self.max_timesteps = QSpinBox()
|
||||
self.max_timesteps.setRange(1, 10000)
|
||||
self.max_timesteps.setValue(500)
|
||||
self.num_episodes = QSpinBox()
|
||||
self.num_episodes.setRange(1, 100)
|
||||
self.num_episodes.setValue(1)
|
||||
self.frame_rate = QSpinBox()
|
||||
self.frame_rate.setRange(1, 120)
|
||||
self.frame_rate.setValue(30)
|
||||
def create_general_tab(self):
|
||||
layout = QVBoxLayout(self.general_tab)
|
||||
|
||||
basic_form.addRow("数据集目录:", dir_layout)
|
||||
basic_form.addRow("任务名称:", self.task_name)
|
||||
basic_form.addRow("起始集索引:", self.episode_idx)
|
||||
basic_form.addRow("最大时间步:", self.max_timesteps)
|
||||
basic_form.addRow("集数:", self.num_episodes)
|
||||
basic_form.addRow("帧率:", self.frame_rate)
|
||||
# Config File Path
|
||||
config_group = QGroupBox("Configuration File")
|
||||
config_layout = QHBoxLayout()
|
||||
|
||||
basic_group.setLayout(basic_form)
|
||||
config_layout.addWidget(basic_group)
|
||||
self.config_path_edit = QLineEdit(self.config_path)
|
||||
self.browse_config_button = QPushButton("Browse...")
|
||||
|
||||
# 相机话题组
|
||||
camera_group = QGroupBox("相机话题")
|
||||
camera_form = QFormLayout()
|
||||
config_layout.addWidget(QLabel("Config File:"))
|
||||
config_layout.addWidget(self.config_path_edit)
|
||||
config_layout.addWidget(self.browse_config_button)
|
||||
|
||||
self.img_front_topic = QLineEdit('/camera_f/color/image_raw')
|
||||
self.img_left_topic = QLineEdit('/camera_l/color/image_raw')
|
||||
self.img_right_topic = QLineEdit('/camera_r/color/image_raw')
|
||||
|
||||
self.img_front_depth_topic = QLineEdit('/camera_f/depth/image_raw')
|
||||
self.img_left_depth_topic = QLineEdit('/camera_l/depth/image_raw')
|
||||
self.img_right_depth_topic = QLineEdit('/camera_r/depth/image_raw')
|
||||
|
||||
self.use_depth_image = QCheckBox("使用深度图像")
|
||||
|
||||
camera_form.addRow("前置相机:", self.img_front_topic)
|
||||
camera_form.addRow("左腕相机:", self.img_left_topic)
|
||||
camera_form.addRow("右腕相机:", self.img_right_topic)
|
||||
camera_form.addRow("前置深度:", self.img_front_depth_topic)
|
||||
camera_form.addRow("左腕深度:", self.img_left_depth_topic)
|
||||
camera_form.addRow("右腕深度:", self.img_right_depth_topic)
|
||||
camera_form.addRow("", self.use_depth_image)
|
||||
|
||||
camera_group.setLayout(camera_form)
|
||||
config_layout.addWidget(camera_group)
|
||||
|
||||
# 机器人话题组
|
||||
robot_group = QGroupBox("机器人话题")
|
||||
robot_form = QFormLayout()
|
||||
|
||||
self.master_arm_left_topic = QLineEdit('/master/joint_left')
|
||||
self.master_arm_right_topic = QLineEdit('/master/joint_right')
|
||||
self.puppet_arm_left_topic = QLineEdit('/puppet/joint_left')
|
||||
self.puppet_arm_right_topic = QLineEdit('/puppet/joint_right')
|
||||
self.robot_base_topic = QLineEdit('/odom')
|
||||
self.use_robot_base = QCheckBox("使用机器人底盘")
|
||||
|
||||
robot_form.addRow("主左臂:", self.master_arm_left_topic)
|
||||
robot_form.addRow("主右臂:", self.master_arm_right_topic)
|
||||
robot_form.addRow("从左臂:", self.puppet_arm_left_topic)
|
||||
robot_form.addRow("从右臂:", self.puppet_arm_right_topic)
|
||||
robot_form.addRow("底盘:", self.robot_base_topic)
|
||||
robot_form.addRow("", self.use_robot_base)
|
||||
|
||||
robot_group.setLayout(robot_form)
|
||||
config_layout.addWidget(robot_group)
|
||||
|
||||
# 相机名称配置
|
||||
camera_names_group = QGroupBox("相机名称")
|
||||
camera_names_layout = QVBoxLayout()
|
||||
|
||||
self.camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
|
||||
self.camera_checkboxes = {}
|
||||
|
||||
for cam_name in self.camera_names:
|
||||
self.camera_checkboxes[cam_name] = QCheckBox(cam_name)
|
||||
self.camera_checkboxes[cam_name].setChecked(True)
|
||||
camera_names_layout.addWidget(self.camera_checkboxes[cam_name])
|
||||
|
||||
camera_names_group.setLayout(camera_names_layout)
|
||||
config_layout.addWidget(camera_names_group)
|
||||
|
||||
# 保存配置按钮
|
||||
self.save_config_button = QPushButton("保存配置")
|
||||
self.save_config_button.clicked.connect(self.save_config)
|
||||
config_layout.addWidget(self.save_config_button)
|
||||
|
||||
def setup_collection_tab(self):
|
||||
collection_layout = QVBoxLayout(self.collection_tab)
|
||||
|
||||
# 当前配置展示
|
||||
config_group = QGroupBox("当前配置")
|
||||
self.config_text = QTextEdit()
|
||||
self.config_text.setReadOnly(True)
|
||||
config_layout = QVBoxLayout()
|
||||
config_layout.addWidget(self.config_text)
|
||||
config_group.setLayout(config_layout)
|
||||
collection_layout.addWidget(config_group)
|
||||
layout.addWidget(config_group)
|
||||
|
||||
# 操作按钮
|
||||
buttons_layout = QHBoxLayout()
|
||||
# Dataset Directory
|
||||
dir_group = QGroupBox("Dataset Directory")
|
||||
dir_layout = QHBoxLayout()
|
||||
|
||||
self.start_button = QPushButton("开始收集")
|
||||
self.start_button.setIcon(QIcon.fromTheme("media-playback-start"))
|
||||
self.start_button.clicked.connect(self.start_collection)
|
||||
self.dataset_dir_edit = QLineEdit()
|
||||
self.browse_dir_button = QPushButton("Browse...")
|
||||
|
||||
self.stop_button = QPushButton("停止")
|
||||
self.stop_button.setIcon(QIcon.fromTheme("media-playback-stop"))
|
||||
self.stop_button.setEnabled(False)
|
||||
self.stop_button.clicked.connect(self.stop_collection)
|
||||
dir_layout.addWidget(QLabel("Dataset Directory:"))
|
||||
dir_layout.addWidget(self.dataset_dir_edit)
|
||||
dir_layout.addWidget(self.browse_dir_button)
|
||||
|
||||
buttons_layout.addWidget(self.start_button)
|
||||
buttons_layout.addWidget(self.stop_button)
|
||||
collection_layout.addLayout(buttons_layout)
|
||||
dir_group.setLayout(dir_layout)
|
||||
layout.addWidget(dir_group)
|
||||
|
||||
# 进度条
|
||||
self.progress_bar = QProgressBar()
|
||||
self.progress_bar.setValue(0)
|
||||
collection_layout.addWidget(self.progress_bar)
|
||||
# Task Settings
|
||||
task_group = QGroupBox("Task Settings")
|
||||
task_layout = QVBoxLayout()
|
||||
|
||||
# 日志输出
|
||||
log_group = QGroupBox("操作日志")
|
||||
self.log_text = QTextEdit()
|
||||
self.log_text.setReadOnly(True)
|
||||
log_layout = QVBoxLayout()
|
||||
log_layout.addWidget(self.log_text)
|
||||
log_group.setLayout(log_layout)
|
||||
collection_layout.addWidget(log_group)
|
||||
self.task_name_edit = QLineEdit()
|
||||
self.episode_idx_spin = QSpinBox()
|
||||
self.episode_idx_spin.setRange(0, 9999)
|
||||
self.max_timesteps_spin = QSpinBox()
|
||||
self.max_timesteps_spin.setRange(1, 10000)
|
||||
self.num_episodes_spin = QSpinBox()
|
||||
self.num_episodes_spin.setRange(1, 1000)
|
||||
self.frame_rate_spin = QSpinBox()
|
||||
self.frame_rate_spin.setRange(1, 60)
|
||||
|
||||
task_layout.addWidget(QLabel("Task Name:"))
|
||||
task_layout.addWidget(self.task_name_edit)
|
||||
task_layout.addWidget(QLabel("Episode Index:"))
|
||||
task_layout.addWidget(self.episode_idx_spin)
|
||||
task_layout.addWidget(QLabel("Max Timesteps:"))
|
||||
task_layout.addWidget(self.max_timesteps_spin)
|
||||
task_layout.addWidget(QLabel("Number of Episodes:"))
|
||||
task_layout.addWidget(self.num_episodes_spin)
|
||||
task_layout.addWidget(QLabel("Frame Rate:"))
|
||||
task_layout.addWidget(self.frame_rate_spin)
|
||||
|
||||
task_group.setLayout(task_layout)
|
||||
layout.addWidget(task_group)
|
||||
|
||||
# Options
|
||||
options_group = QGroupBox("Options")
|
||||
options_layout = QVBoxLayout()
|
||||
|
||||
self.use_robot_base_check = QCheckBox("Use Robot Base")
|
||||
self.use_depth_image_check = QCheckBox("Use Depth Image")
|
||||
|
||||
options_layout.addWidget(self.use_robot_base_check)
|
||||
options_layout.addWidget(self.use_depth_image_check)
|
||||
|
||||
options_group.setLayout(options_layout)
|
||||
layout.addWidget(options_group)
|
||||
|
||||
layout.addStretch()
|
||||
|
||||
def create_camera_tab(self):
|
||||
layout = QVBoxLayout(self.camera_tab)
|
||||
|
||||
# Color Image Topics
|
||||
color_group = QGroupBox("Color Image Topics")
|
||||
color_layout = QVBoxLayout()
|
||||
|
||||
self.img_front_topic_edit = QLineEdit()
|
||||
self.img_left_topic_edit = QLineEdit()
|
||||
self.img_right_topic_edit = QLineEdit()
|
||||
|
||||
color_layout.addWidget(QLabel("Front Camera Topic:"))
|
||||
color_layout.addWidget(self.img_front_topic_edit)
|
||||
color_layout.addWidget(QLabel("Left Camera Topic:"))
|
||||
color_layout.addWidget(self.img_left_topic_edit)
|
||||
color_layout.addWidget(QLabel("Right Camera Topic:"))
|
||||
color_layout.addWidget(self.img_right_topic_edit)
|
||||
|
||||
color_group.setLayout(color_layout)
|
||||
layout.addWidget(color_group)
|
||||
|
||||
# Depth Image Topics
|
||||
depth_group = QGroupBox("Depth Image Topics")
|
||||
depth_layout = QVBoxLayout()
|
||||
|
||||
self.img_front_depth_topic_edit = QLineEdit()
|
||||
self.img_left_depth_topic_edit = QLineEdit()
|
||||
self.img_right_depth_topic_edit = QLineEdit()
|
||||
|
||||
depth_layout.addWidget(QLabel("Front Depth Topic:"))
|
||||
depth_layout.addWidget(self.img_front_depth_topic_edit)
|
||||
depth_layout.addWidget(QLabel("Left Depth Topic:"))
|
||||
depth_layout.addWidget(self.img_left_depth_topic_edit)
|
||||
depth_layout.addWidget(QLabel("Right Depth Topic:"))
|
||||
depth_layout.addWidget(self.img_right_depth_topic_edit)
|
||||
|
||||
depth_group.setLayout(depth_layout)
|
||||
layout.addWidget(depth_group)
|
||||
|
||||
layout.addStretch()
|
||||
|
||||
def create_arm_tab(self):
|
||||
layout = QVBoxLayout(self.arm_tab)
|
||||
|
||||
# Master Arm Topics
|
||||
master_group = QGroupBox("Master Arm Topics")
|
||||
master_layout = QVBoxLayout()
|
||||
|
||||
self.master_arm_left_topic_edit = QLineEdit()
|
||||
self.master_arm_right_topic_edit = QLineEdit()
|
||||
|
||||
master_layout.addWidget(QLabel("Master Left Arm Topic:"))
|
||||
master_layout.addWidget(self.master_arm_left_topic_edit)
|
||||
master_layout.addWidget(QLabel("Master Right Arm Topic:"))
|
||||
master_layout.addWidget(self.master_arm_right_topic_edit)
|
||||
|
||||
master_group.setLayout(master_layout)
|
||||
layout.addWidget(master_group)
|
||||
|
||||
# Puppet Arm Topics
|
||||
puppet_group = QGroupBox("Puppet Arm Topics")
|
||||
puppet_layout = QVBoxLayout()
|
||||
|
||||
self.puppet_arm_left_topic_edit = QLineEdit()
|
||||
self.puppet_arm_right_topic_edit = QLineEdit()
|
||||
|
||||
puppet_layout.addWidget(QLabel("Puppet Left Arm Topic:"))
|
||||
puppet_layout.addWidget(self.puppet_arm_left_topic_edit)
|
||||
puppet_layout.addWidget(QLabel("Puppet Right Arm Topic:"))
|
||||
puppet_layout.addWidget(self.puppet_arm_right_topic_edit)
|
||||
|
||||
puppet_group.setLayout(puppet_layout)
|
||||
layout.addWidget(puppet_group)
|
||||
|
||||
# Robot Base Topic
|
||||
base_group = QGroupBox("Robot Base Topic")
|
||||
base_layout = QVBoxLayout()
|
||||
|
||||
self.robot_base_topic_edit = QLineEdit()
|
||||
|
||||
base_layout.addWidget(QLabel("Robot Base Topic:"))
|
||||
base_layout.addWidget(self.robot_base_topic_edit)
|
||||
|
||||
base_group.setLayout(base_layout)
|
||||
layout.addWidget(base_group)
|
||||
|
||||
layout.addStretch()
|
||||
|
||||
def setup_connections(self):
|
||||
self.load_config_button.clicked.connect(self.load_config)
|
||||
self.save_config_button.clicked.connect(self.save_config)
|
||||
self.browse_config_button.clicked.connect(self.browse_config_file)
|
||||
self.browse_dir_button.clicked.connect(self.browse_dataset_dir)
|
||||
self.start_button.clicked.connect(self.start_recording)
|
||||
self.stop_button.clicked.connect(self.stop_recording)
|
||||
self.exit_button.clicked.connect(self.close)
|
||||
|
||||
def load_default_config(self):
|
||||
default_config = {
|
||||
'dataset_dir': '/home/ubuntu/LYT/lerobot_aloha/datasets/3camera',
|
||||
'task_name': 'aloha_mobile_dummy',
|
||||
'episode_idx': 0,
|
||||
'max_timesteps': 500,
|
||||
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
|
||||
'num_episodes': 50,
|
||||
'img_front_topic': '/camera_f/color/image_raw',
|
||||
'img_left_topic': '/camera_l/color/image_raw',
|
||||
'img_right_topic': '/camera_r/color/image_raw',
|
||||
'img_front_depth_topic': '/camera_f/depth/image_raw',
|
||||
'img_left_depth_topic': '/camera_l/depth/image_raw',
|
||||
'img_right_depth_topic': '/camera_r/depth/image_raw',
|
||||
'master_arm_left_topic': '/master/joint_left',
|
||||
'master_arm_right_topic': '/master/joint_right',
|
||||
'puppet_arm_left_topic': '/puppet/joint_left',
|
||||
'puppet_arm_right_topic': '/puppet/joint_right',
|
||||
'robot_base_topic': '/odom',
|
||||
'use_robot_base': False,
|
||||
'use_depth_image': False,
|
||||
'frame_rate': 30
|
||||
}
|
||||
|
||||
# Update UI with default values
|
||||
self.update_ui_from_config(default_config)
|
||||
|
||||
def update_ui_from_config(self, config):
|
||||
"""Update UI elements from a config dictionary"""
|
||||
self.dataset_dir_edit.setText(config.get('dataset_dir', ''))
|
||||
self.task_name_edit.setText(config.get('task_name', ''))
|
||||
self.episode_idx_spin.setValue(config.get('episode_idx', 0))
|
||||
self.max_timesteps_spin.setValue(config.get('max_timesteps', 500))
|
||||
self.num_episodes_spin.setValue(config.get('num_episodes', 1))
|
||||
self.frame_rate_spin.setValue(config.get('frame_rate', 30))
|
||||
|
||||
self.img_front_topic_edit.setText(config.get('img_front_topic', ''))
|
||||
self.img_left_topic_edit.setText(config.get('img_left_topic', ''))
|
||||
self.img_right_topic_edit.setText(config.get('img_right_topic', ''))
|
||||
|
||||
self.img_front_depth_topic_edit.setText(config.get('img_front_depth_topic', ''))
|
||||
self.img_left_depth_topic_edit.setText(config.get('img_left_depth_topic', ''))
|
||||
self.img_right_depth_topic_edit.setText(config.get('img_right_depth_topic', ''))
|
||||
|
||||
self.master_arm_left_topic_edit.setText(config.get('master_arm_left_topic', ''))
|
||||
self.master_arm_right_topic_edit.setText(config.get('master_arm_right_topic', ''))
|
||||
self.puppet_arm_left_topic_edit.setText(config.get('puppet_arm_left_topic', ''))
|
||||
self.puppet_arm_right_topic_edit.setText(config.get('puppet_arm_right_topic', ''))
|
||||
|
||||
self.robot_base_topic_edit.setText(config.get('robot_base_topic', ''))
|
||||
self.use_robot_base_check.setChecked(config.get('use_robot_base', False))
|
||||
self.use_depth_image_check.setChecked(config.get('use_depth_image', False))
|
||||
|
||||
def get_config_from_ui(self):
|
||||
"""Get current UI values as a config dictionary"""
|
||||
config = {
|
||||
'dataset_dir': self.dataset_dir_edit.text(),
|
||||
'task_name': self.task_name_edit.text(),
|
||||
'episode_idx': self.episode_idx_spin.value(),
|
||||
'max_timesteps': self.max_timesteps_spin.value(),
|
||||
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
|
||||
'num_episodes': self.num_episodes_spin.value(),
|
||||
'img_front_topic': self.img_front_topic_edit.text(),
|
||||
'img_left_topic': self.img_left_topic_edit.text(),
|
||||
'img_right_topic': self.img_right_topic_edit.text(),
|
||||
'img_front_depth_topic': self.img_front_depth_topic_edit.text(),
|
||||
'img_left_depth_topic': self.img_left_depth_topic_edit.text(),
|
||||
'img_right_depth_topic': self.img_right_depth_topic_edit.text(),
|
||||
'master_arm_left_topic': self.master_arm_left_topic_edit.text(),
|
||||
'master_arm_right_topic': self.master_arm_right_topic_edit.text(),
|
||||
'puppet_arm_left_topic': self.puppet_arm_left_topic_edit.text(),
|
||||
'puppet_arm_right_topic': self.puppet_arm_right_topic_edit.text(),
|
||||
'robot_base_topic': self.robot_base_topic_edit.text(),
|
||||
'use_robot_base': self.use_robot_base_check.isChecked(),
|
||||
'use_depth_image': self.use_depth_image_check.isChecked(),
|
||||
'frame_rate': self.frame_rate_spin.value()
|
||||
}
|
||||
return config
|
||||
|
||||
def browse_config_file(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(
|
||||
self, "Select Config File", "", "YAML Files (*.yaml *.yml)"
|
||||
)
|
||||
if file_path:
|
||||
self.config_path_edit.setText(file_path)
|
||||
self.load_config()
|
||||
|
||||
def browse_dataset_dir(self):
|
||||
directory = QFileDialog.getExistingDirectory(self, "选择数据集目录", self.dataset_dir.text())
|
||||
if directory:
|
||||
self.dataset_dir.setText(directory)
|
||||
|
||||
def save_config(self):
|
||||
# 更新配置显示
|
||||
selected_cameras = [cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()]
|
||||
|
||||
config_text = f"""
|
||||
任务名称: {self.task_name.text()}
|
||||
数据集目录: {self.dataset_dir.text()}
|
||||
起始集索引: {self.episode_idx.value()}
|
||||
最大时间步: {self.max_timesteps.value()}
|
||||
集数: {self.num_episodes.value()}
|
||||
帧率: {self.frame_rate.value()}
|
||||
使用深度图像: {"是" if self.use_depth_image.isChecked() else "否"}
|
||||
使用机器人底盘: {"是" if self.use_robot_base.isChecked() else "否"}
|
||||
相机: {', '.join(selected_cameras)}
|
||||
"""
|
||||
|
||||
self.config_text.setText(config_text)
|
||||
self.tab_widget.setCurrentIndex(1) # 切换到收集选项卡
|
||||
|
||||
QMessageBox.information(self, "配置已保存", "配置已更新,可以开始数据收集")
|
||||
|
||||
def start_collection(self):
|
||||
if not self.task_name.text():
|
||||
QMessageBox.warning(self, "配置错误", "请输入有效的任务名称")
|
||||
dir_path = QFileDialog.getExistingDirectory(
|
||||
self, "Select Dataset Directory"
|
||||
)
|
||||
if dir_path:
|
||||
self.dataset_dir_edit.setText(dir_path)
|
||||
|
||||
def load_config(self):
|
||||
config_path = self.config_path_edit.text()
|
||||
if not os.path.exists(config_path):
|
||||
QMessageBox.warning(self, "Warning", f"Config file not found: {config_path}")
|
||||
return
|
||||
|
||||
# 构建参数
|
||||
args = argparse.Namespace(
|
||||
dataset_dir=self.dataset_dir.text(),
|
||||
task_name=self.task_name.text(),
|
||||
episode_idx=self.episode_idx.value(),
|
||||
max_timesteps=self.max_timesteps.value(),
|
||||
num_episodes=self.num_episodes.value(),
|
||||
camera_names=[cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()],
|
||||
img_front_topic=self.img_front_topic.text(),
|
||||
img_left_topic=self.img_left_topic.text(),
|
||||
img_right_topic=self.img_right_topic.text(),
|
||||
img_front_depth_topic=self.img_front_depth_topic.text(),
|
||||
img_left_depth_topic=self.img_left_depth_topic.text(),
|
||||
img_right_depth_topic=self.img_right_depth_topic.text(),
|
||||
master_arm_left_topic=self.master_arm_left_topic.text(),
|
||||
master_arm_right_topic=self.master_arm_right_topic.text(),
|
||||
puppet_arm_left_topic=self.puppet_arm_left_topic.text(),
|
||||
puppet_arm_right_topic=self.puppet_arm_right_topic.text(),
|
||||
robot_base_topic=self.robot_base_topic.text(),
|
||||
use_robot_base=self.use_robot_base.isChecked(),
|
||||
use_depth_image=self.use_depth_image.isChecked(),
|
||||
frame_rate=self.frame_rate.value()
|
||||
)
|
||||
|
||||
# 更新UI状态
|
||||
self.start_button.setEnabled(False)
|
||||
self.stop_button.setEnabled(True)
|
||||
self.progress_bar.setValue(0)
|
||||
self.log_text.clear()
|
||||
self.log_text.append("正在初始化数据收集...\n")
|
||||
|
||||
# 创建并启动线程
|
||||
self.collection_thread = DataCollectionThread(args)
|
||||
self.collection_thread.update_signal.connect(self.update_log)
|
||||
self.collection_thread.progress_signal.connect(self.update_progress)
|
||||
self.collection_thread.finish_signal.connect(self.collection_finished)
|
||||
self.collection_thread.start()
|
||||
|
||||
def stop_collection(self):
|
||||
if self.collection_thread and self.collection_thread.isRunning():
|
||||
self.log_text.append("正在停止数据收集...\n")
|
||||
self.collection_thread.stop()
|
||||
self.collection_thread = None
|
||||
try:
|
||||
with open(config_path, 'r') as f:
|
||||
config = yaml.safe_load(f)
|
||||
self.update_ui_from_config(config)
|
||||
self.statusBar().showMessage(f"Config loaded from {config_path}", 3000)
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "Error", f"Failed to load config: {str(e)}")
|
||||
|
||||
self.start_button.setEnabled(True)
|
||||
self.stop_button.setEnabled(False)
|
||||
|
||||
@pyqtSlot(str)
|
||||
def update_log(self, message):
|
||||
self.log_text.append(message)
|
||||
# 自动滚动到底部
|
||||
cursor = self.log_text.textCursor()
|
||||
cursor.movePosition(QTextCursor.End)
|
||||
self.log_text.setTextCursor(cursor)
|
||||
|
||||
@pyqtSlot(int)
|
||||
def update_progress(self, value):
|
||||
self.progress_bar.setValue(value)
|
||||
|
||||
@pyqtSlot(bool, str)
|
||||
def collection_finished(self, success, message):
|
||||
self.start_button.setEnabled(True)
|
||||
self.stop_button.setEnabled(False)
|
||||
|
||||
if success:
|
||||
QMessageBox.information(self, "完成", message)
|
||||
else:
|
||||
QMessageBox.warning(self, "出错", f"数据收集失败: {message}")
|
||||
def save_config(self):
|
||||
config_path = self.config_path_edit.text()
|
||||
if not config_path:
|
||||
QMessageBox.warning(self, "Warning", "Please specify a config file path")
|
||||
return
|
||||
|
||||
# 更新episode_idx值
|
||||
if success and self.num_episodes.value() > 0:
|
||||
self.episode_idx.setValue(self.episode_idx.value() + self.num_episodes.value())
|
||||
try:
|
||||
config = self.get_config_from_ui()
|
||||
with open(config_path, 'w') as f:
|
||||
yaml.dump(config, f, default_flow_style=False)
|
||||
self.statusBar().showMessage(f"Config saved to {config_path}", 3000)
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "Error", f"Failed to save config: {str(e)}")
|
||||
|
||||
def start_recording(self):
|
||||
try:
|
||||
# Save current config to a temporary file
|
||||
temp_config_path = "/tmp/aloha_temp_config.yaml"
|
||||
config = self.get_config_from_ui()
|
||||
|
||||
# Validate inputs
|
||||
if not config['dataset_dir']:
|
||||
QMessageBox.warning(self, "Warning", "Dataset directory cannot be empty!")
|
||||
return
|
||||
|
||||
if not config['task_name']:
|
||||
QMessageBox.warning(self, "Warning", "Task name cannot be empty!")
|
||||
return
|
||||
|
||||
with open(temp_config_path, 'w') as f:
|
||||
yaml.dump(config, f, default_flow_style=False)
|
||||
|
||||
self.statusBar().showMessage("Recording started...")
|
||||
|
||||
# Start recording with the temporary config file
|
||||
exit_code = main(temp_config_path)
|
||||
|
||||
if exit_code == 0:
|
||||
self.statusBar().showMessage("Recording completed successfully!", 5000)
|
||||
else:
|
||||
self.statusBar().showMessage("Recording completed with errors", 5000)
|
||||
|
||||
except Exception as e:
|
||||
QMessageBox.critical(self, "Error", f"An error occurred: {str(e)}")
|
||||
self.statusBar().showMessage("Recording failed", 5000)
|
||||
|
||||
def stop_recording(self):
|
||||
# In a real application, this would signal the recording thread to stop
|
||||
self.statusBar().showMessage("Recording stopped", 5000)
|
||||
QMessageBox.information(self, "Info", "Stop recording requested. This would stop the recording in a real implementation.")
|
||||
|
||||
def main():
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
window = AlohaDataCollectionGUI()
|
||||
window.show()
|
||||
sys.exit(app.exec_())
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user