Use the native data collection code

2025-04-13 21:41:45 +08:00
parent 3df284ddd1
commit a4fe5ee09a
20 changed files with 1477 additions and 1544 deletions


@@ -1,3 +0,0 @@
python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data
python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1


@@ -1,146 +0,0 @@
robot_type: aloha_agilex
ros_node_name: record_episodes
cameras:
cam_front:
img_topic_name: /camera_f/color/image_raw
depth_topic_name: /camera_f/depth/image_raw
width: 480
height: 640
rgb_shape: [480, 640, 3]
cam_left:
img_topic_name: /camera_l/color/image_raw
depth_topic_name: /camera_l/depth/image_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
cam_right:
img_topic_name: /camera_r/color/image_raw
depth_topic_name: /camera_r/depth/image_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
cam_high:
img_topic_name: /camera/color/image_raw
depth_topic_name: /camera/depth/image_rect_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
arm:
master_left:
topic_name: /master/joint_left
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none"
]
master_right:
topic_name: /master/joint_right
motors: [
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
puppet_left:
topic_name: /puppet/joint_left
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none"
]
puppet_right:
topic_name: /puppet/joint_right
motors: [
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
# follow the joint name in ros
state:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
velocity:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
effort:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
action:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]


@@ -0,0 +1,305 @@
import cv2
import numpy as np
import dm_env
import collections
from collections import deque
import rospy
from sensor_msgs.msg import JointState
from sensor_msgs.msg import Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
from utils import display_camera_grid, save_data
class AlohaRobotRos:
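# Buffers the latest ROS messages (camera images, master/puppet joint states, base odometry) in per-topic deques and exposes time-synchronized frames for episode recording.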
def __init__(self, args):
self.robot_base_deque = None
self.puppet_arm_right_deque = None
self.puppet_arm_left_deque = None
self.master_arm_right_deque = None
self.master_arm_left_deque = None
self.img_front_deque = None
self.img_right_deque = None
self.img_left_deque = None
self.img_front_depth_deque = None
self.img_right_depth_deque = None
self.img_left_depth_deque = None
self.bridge = None
self.args = args
self.init()
self.init_ros()
def init(self):
self.bridge = CvBridge()
self.img_left_deque = deque()
self.img_right_deque = deque()
self.img_front_deque = deque()
self.img_left_depth_deque = deque()
self.img_right_depth_deque = deque()
self.img_front_depth_deque = deque()
self.master_arm_left_deque = deque()
self.master_arm_right_deque = deque()
self.puppet_arm_left_deque = deque()
self.puppet_arm_right_deque = deque()
self.robot_base_deque = deque()
def get_frame(self):
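# Synchronize all streams: take the oldest of the latest image timestamps, verify every stream has reached it, drop stale messages, and return one aligned frame (or False if any stream lags).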
print(len(self.img_left_deque), len(self.img_right_deque), len(self.img_front_deque),
len(self.img_left_depth_deque), len(self.img_right_depth_deque), len(self.img_front_depth_deque))
if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
return False
if self.args.use_depth_image:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
else:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.master_arm_left_deque) == 0 or self.master_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.master_arm_right_deque) == 0 or self.master_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
return False
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_deque.popleft()
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
# print("img_left:", img_left.shape)
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_deque.popleft()
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_deque.popleft()
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
while self.master_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.master_arm_left_deque.popleft()
master_arm_left = self.master_arm_left_deque.popleft()
while self.master_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.master_arm_right_deque.popleft()
master_arm_right = self.master_arm_right_deque.popleft()
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_left_deque.popleft()
puppet_arm_left = self.puppet_arm_left_deque.popleft()
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_right_deque.popleft()
puppet_arm_right = self.puppet_arm_right_deque.popleft()
img_left_depth = None
if self.args.use_depth_image:
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_depth_deque.popleft()
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
top, bottom, left, right = 40, 40, 0, 0
img_left_depth = cv2.copyMakeBorder(img_left_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
img_right_depth = None
if self.args.use_depth_image:
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_depth_deque.popleft()
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
top, bottom, left, right = 40, 40, 0, 0
img_right_depth = cv2.copyMakeBorder(img_right_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
img_front_depth = None
if self.args.use_depth_image:
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_depth_deque.popleft()
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
top, bottom, left, right = 40, 40, 0, 0
img_front_depth = cv2.copyMakeBorder(img_front_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
robot_base = None
if self.args.use_robot_base:
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
self.robot_base_deque.popleft()
robot_base = self.robot_base_deque.popleft()
return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base)
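# ROS subscriber callbacks: each appends the incoming message and caps its deque at 2000 entries, dropping the oldest.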
def img_left_callback(self, msg):
if len(self.img_left_deque) >= 2000:
self.img_left_deque.popleft()
self.img_left_deque.append(msg)
def img_right_callback(self, msg):
if len(self.img_right_deque) >= 2000:
self.img_right_deque.popleft()
self.img_right_deque.append(msg)
def img_front_callback(self, msg):
if len(self.img_front_deque) >= 2000:
self.img_front_deque.popleft()
self.img_front_deque.append(msg)
def img_left_depth_callback(self, msg):
# import pdb
# pdb.set_trace()
if len(self.img_left_depth_deque) >= 2000:
self.img_left_depth_deque.popleft()
self.img_left_depth_deque.append(msg)
def img_right_depth_callback(self, msg):
if len(self.img_right_depth_deque) >= 2000:
self.img_right_depth_deque.popleft()
self.img_right_depth_deque.append(msg)
def img_front_depth_callback(self, msg):
if len(self.img_front_depth_deque) >= 2000:
self.img_front_depth_deque.popleft()
self.img_front_depth_deque.append(msg)
def master_arm_left_callback(self, msg):
if len(self.master_arm_left_deque) >= 2000:
self.master_arm_left_deque.popleft()
self.master_arm_left_deque.append(msg)
def master_arm_right_callback(self, msg):
if len(self.master_arm_right_deque) >= 2000:
self.master_arm_right_deque.popleft()
self.master_arm_right_deque.append(msg)
def puppet_arm_left_callback(self, msg):
if len(self.puppet_arm_left_deque) >= 2000:
self.puppet_arm_left_deque.popleft()
self.puppet_arm_left_deque.append(msg)
def puppet_arm_right_callback(self, msg):
if len(self.puppet_arm_right_deque) >= 2000:
self.puppet_arm_right_deque.popleft()
self.puppet_arm_right_deque.append(msg)
def robot_base_callback(self, msg):
if len(self.robot_base_deque) >= 2000:
self.robot_base_deque.popleft()
self.robot_base_deque.append(msg)
def init_ros(self):
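# Create the 'record_episodes' ROS node and subscribe to the image, joint-state and (optionally) depth topics; the base-odometry subscriber is currently commented out.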
rospy.init_node('record_episodes', anonymous=True)
rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
if self.args.use_depth_image:
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.master_arm_left_topic, JointState, self.master_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.master_arm_right_topic, JointState, self.master_arm_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
# rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
def process(self):
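# Recording loop: wait for synchronized frames at args.frame_rate, build an observation dict (images, qpos, qvel, effort, base_vel) and collect dm_env timesteps plus master-arm actions until max_timesteps is reached.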
timesteps = []
actions = []
# Image data (dummy placeholder)
image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)
image_dict = dict()
for cam_name in self.args.camera_names:
image_dict[cam_name] = image
count = 0
# input_key = input("please input s:")
# while input_key != 's' and not rospy.is_shutdown():
# input_key = input("please input s:")
rate = rospy.Rate(self.args.frame_rate)
print_flag = True
while (count < self.args.max_timesteps + 1) and not rospy.is_shutdown():
# 2. Collect data
result = self.get_frame()
# import pdb
# pdb.set_trace()
if not result:
if print_flag:
print("syn fail\n")
print_flag = False
rate.sleep()
continue
print_flag = True
count += 1
(img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base) = result
# 2.1 Image information
image_dict = dict()
image_dict[self.args.camera_names[0]] = img_front
image_dict[self.args.camera_names[1]] = img_left
image_dict[self.args.camera_names[2]] = img_right
# import pdb
# pdb.set_trace()
display_camera_grid(image_dict)
# 2.2 Puppet-arm state; it is published automatically while the arms are in teach (demonstration) mode
obs = collections.OrderedDict()  # ordered dictionary
obs['images'] = image_dict
if self.args.use_depth_image:
image_dict_depth = dict()
image_dict_depth[self.args.camera_names[0]] = img_front_depth
image_dict_depth[self.args.camera_names[1]] = img_left_depth
image_dict_depth[self.args.camera_names[2]] = img_right_depth
obs['images_depth'] = image_dict_depth
obs['qpos'] = np.concatenate((np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
obs['qvel'] = np.concatenate((np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
obs['effort'] = np.concatenate((np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
if self.args.use_robot_base:
obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
else:
obs['base_vel'] = [0.0, 0.0]
# The first frame only stores StepType.FIRST; no action is recorded for it
if count == 1:
ts = dm_env.TimeStep(
step_type=dm_env.StepType.FIRST,
reward=None,
discount=None,
observation=obs)
timesteps.append(ts)
continue
# Time step
ts = dm_env.TimeStep(
step_type=dm_env.StepType.MID,
reward=None,
discount=None,
observation=obs)
# The master-arm positions are stored as the action
action = np.concatenate((np.array(master_arm_left.position), np.array(master_arm_right.position)), axis=0)
actions.append(action)
timesteps.append(ts)
print(f"\n{self.args.episode_idx} | Frame data: {count}\r", end="")
if rospy.is_shutdown():
exit(-1)
rate.sleep()
print("len(timesteps): ", len(timesteps))
print("len(actions) : ", len(actions))
return timesteps, actions

collect_data/collect_data.py Executable file

@@ -0,0 +1,166 @@
import os
import time
import argparse
from aloha_mobile import AlohaRobotRos
from utils import save_data, init_keyboard_listener
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset_dir.',
default="./data", required=False)
parser.add_argument('--task_name', action='store', type=str, help='Task name.',
default="aloha_mobile_dummy", required=False)
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.',
default=0, required=False)
parser.add_argument('--max_timesteps', action='store', type=int, help='Max_timesteps.',
default=500, required=False)
parser.add_argument('--camera_names', action='store', type=str, help='camera_names',
default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False)
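# Note: with type=str a value passed on the command line is parsed as a single string; the list default above is only used when the flag is omitted (nargs='+' would be needed for a real list).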
parser.add_argument('--num_episodes', action='store', type=int, help='Num_episodes.',
default=1, required=False)
# topic name of color image
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
default='/camera_f/color/image_raw', required=False)
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
default='/camera_l/color/image_raw', required=False)
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
default='/camera_r/color/image_raw', required=False)
# topic name of depth image
parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
default='/camera_f/depth/image_raw', required=False)
parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
default='/camera_l/depth/image_raw', required=False)
parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
default='/camera_r/depth/image_raw', required=False)
# topic name of arm
parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
default='/master/joint_left', required=False)
parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
default='/master/joint_right', required=False)
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
default='/puppet/joint_left', required=False)
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
default='/puppet/joint_right', required=False)
# topic name of robot_base
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
default='/odom', required=False)
parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
default=False, required=False)
# collect depth image
parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image',
default=False, required=False)
parser.add_argument('--frame_rate', action='store', type=int, help='frame_rate',
default=30, required=False)
args = parser.parse_args()
return args
def main():
args = get_arguments()
ros_operator = AlohaRobotRos(args)
dataset_dir = os.path.join(args.dataset_dir, args.task_name)
# Make sure the dataset directory exists
os.makedirs(dataset_dir, exist_ok=True)
# Single-episode collection mode
if args.num_episodes == 1:
print(f"Recording single episode {args.episode_idx}...")
timesteps, actions = ros_operator.process()
if len(actions) < args.max_timesteps:
print(f"\033[31m\nSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m\n")
return -1
dataset_path = os.path.join(dataset_dir, f"episode_{args.episode_idx}")
save_data(args, timesteps, actions, dataset_path)
print(f"\033[32mEpisode {args.episode_idx} saved successfully at {dataset_path}\033[0m")
return 0
# Multi-episode collection mode
print("""
\033[1;36mKeyboard Controls:\033[0m
\033[1mLeft Arrow\033[0m: Start Recording
\033[1mRight Arrow\033[0m: Save Current Data
\033[1mDown Arrow\033[0m: Discard Current Data
\033[1mUp Arrow\033[0m: Replay Data (if implemented)
\033[1mESC\033[0m: Exit Program
""")
# Initialize the keyboard listener
listener, events = init_keyboard_listener()
episode_idx = args.episode_idx
collected_episodes = 0
try:
while collected_episodes < args.num_episodes:
if events["exit_early"]:
print("\033[33mOperation terminated by user\033[0m")
break
if events["record_start"]:
# Reset the event flags and start a new recording
events["record_start"] = False
events["save_data"] = False
events["discard_data"] = False
print(f"\n\033[1;32mRecording episode {episode_idx}...\033[0m")
timesteps, actions = ros_operator.process()
print(f"\033[1;33mRecorded {len(actions)} timesteps. (→ to save, ↓ to discard)\033[0m")
# Wait for the user to decide whether to save or discard
while True:
if events["save_data"]:
events["save_data"] = False
if len(actions) < args.max_timesteps:
print(f"\033[31mSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m")
else:
dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
save_data(args, timesteps, actions, dataset_path)
print(f"\033[32mEpisode {episode_idx} saved successfully at {dataset_path}\033[0m")
episode_idx += 1
collected_episodes += 1
print(f"\033[1mProgress: {collected_episodes}/{args.num_episodes} episodes collected. (← to start new episode)\033[0m")
break
if events["discard_data"]:
events["discard_data"] = False
print("\033[33mData discarded. Press ← to start a new recording.\033[0m")
break
if events["exit_early"]:
print("\033[33mOperation terminated by user\033[0m")
return 0
time.sleep(0.1)  # reduce CPU usage
time.sleep(0.1)  # reduce CPU usage
if collected_episodes == args.num_episodes:
print(f"\n\033[1;32mData collection complete! All {args.num_episodes} episodes collected.\033[0m")
finally:
# Make sure the listener is cleaned up
if listener:
listener.stop()
print("Keyboard listener stopped")
if __name__ == '__main__':
try:
exit_code = main()
exit(exit_code if exit_code is not None else 0)
except KeyboardInterrupt:
print("\n\033[33mProgram interrupted by user\033[0m")
exit(0)
except Exception as e:
print(f"\n\033[31mError: {e}\033[0m")
# python collect_data.py --dataset_dir ~/data --max_timesteps 500 --episode_idx 0


@@ -0,0 +1,416 @@
import os
import sys
import time
import argparse
from aloha_mobile import AlohaRobotRos
from utils import save_data, init_keyboard_listener
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
QLabel, QLineEdit, QPushButton, QCheckBox, QSpinBox,
QGroupBox, QFormLayout, QTabWidget, QTextEdit, QFileDialog,
QMessageBox, QProgressBar, QComboBox)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, pyqtSlot
from PyQt5.QtGui import QFont, QIcon, QTextCursor
class DataCollectionThread(QThread):
"""处理数据收集的线程"""
update_signal = pyqtSignal(str)
progress_signal = pyqtSignal(int)
finish_signal = pyqtSignal(bool, str)
def __init__(self, args, parent=None):
super(DataCollectionThread, self).__init__(parent)
self.args = args
self.is_running = True
self.ros_operator = None
def run(self):
try:
self.update_signal.emit("正在初始化ROS操作...\n")
self.ros_operator = AlohaRobotRos(self.args)
dataset_dir = os.path.join(self.args.dataset_dir, self.args.task_name)
os.makedirs(dataset_dir, exist_ok=True)
# Single-episode collection mode
if self.args.num_episodes == 1:
self.update_signal.emit(f"开始录制第 {self.args.episode_idx} 集...\n")
timesteps, actions = self.ros_operator.process()
if len(actions) < self.args.max_timesteps:
self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n")
self.finish_signal.emit(False, f"只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步")
return
dataset_path = os.path.join(dataset_dir, f"episode_{self.args.episode_idx}")
save_data(self.args, timesteps, actions, dataset_path)
self.update_signal.emit(f"{self.args.episode_idx} 集成功保存到 {dataset_path}.\n")
self.finish_signal.emit(True, "数据收集完成")
# Multi-episode collection mode
else:
self.update_signal.emit("""
键盘控制:
← 左箭头: 开始录制
→ 右箭头: 保存当前数据
↓ 下箭头: 丢弃当前数据
ESC: 退出程序
""")
# Initialize the keyboard listener
listener, events = init_keyboard_listener()
episode_idx = self.args.episode_idx
collected_episodes = 0
try:
while collected_episodes < self.args.num_episodes and self.is_running:
if events["exit_early"]:
self.update_signal.emit("操作被用户终止.\n")
break
if events["record_start"]:
# Reset the event flags and start a new recording
events["record_start"] = False
events["save_data"] = False
events["discard_data"] = False
self.update_signal.emit(f"\n正在录制第 {episode_idx} 集...\n")
timesteps, actions = self.ros_operator.process()
self.update_signal.emit(f"已录制 {len(actions)} 个时间步. (→ 保存, ↓ 丢弃)\n")
# Wait for the user to decide whether to save or discard
while self.is_running:
if events["save_data"]:
events["save_data"] = False
if len(actions) < self.args.max_timesteps:
self.update_signal.emit(f"保存失败: 只录制了 {len(actions)}/{self.args.max_timesteps} 个时间步.\n")
else:
dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
save_data(self.args, timesteps, actions, dataset_path)
self.update_signal.emit(f"{episode_idx} 集成功保存到 {dataset_path}.\n")
episode_idx += 1
collected_episodes += 1
progress_percentage = int(collected_episodes * 100 / self.args.num_episodes)
self.progress_signal.emit(progress_percentage)
self.update_signal.emit(f"进度: {collected_episodes}/{self.args.num_episodes} 集已收集. (← 开始新一集)\n")
break
if events["discard_data"]:
events["discard_data"] = False
self.update_signal.emit("数据已丢弃. 请按 ← 开始新的录制.\n")
break
if events["exit_early"]:
self.update_signal.emit("操作被用户终止.\n")
self.is_running = False
break
time.sleep(0.1)  # reduce CPU usage
time.sleep(0.1)  # reduce CPU usage
if collected_episodes == self.args.num_episodes:
self.update_signal.emit(f"\n数据收集完成! 所有 {self.args.num_episodes} 集已收集.\n")
self.finish_signal.emit(True, "全部数据集收集完成")
else:
self.finish_signal.emit(False, "数据收集未完成")
finally:
# Make sure the listener is cleaned up
if listener:
listener.stop()
self.update_signal.emit("键盘监听器已停止\n")
except Exception as e:
self.update_signal.emit(f"错误: {str(e)}\n")
self.finish_signal.emit(False, str(e))
def stop(self):
self.is_running = False
self.wait()
class AlohaDataCollectionGUI(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle("ALOHA 数据收集工具")
self.setGeometry(100, 100, 800, 700)
# Main widget
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
self.main_layout = QVBoxLayout(self.central_widget)
# Create the tab widget
self.tab_widget = QTabWidget()
self.main_layout.addWidget(self.tab_widget)
# Configuration tab
self.config_tab = QWidget()
self.tab_widget.addTab(self.config_tab, "配置")
# Data-collection tab
self.collection_tab = QWidget()
self.tab_widget.addTab(self.collection_tab, "数据收集")
self.setup_config_tab()
self.setup_collection_tab()
# Initialize the data-collection thread
self.collection_thread = None
def setup_config_tab(self):
config_layout = QVBoxLayout(self.config_tab)
# Basic configuration group
basic_group = QGroupBox("基本配置")
basic_form = QFormLayout()
self.dataset_dir = QLineEdit("./data")
self.browse_button = QPushButton("浏览")
self.browse_button.clicked.connect(self.browse_dataset_dir)
dir_layout = QHBoxLayout()
dir_layout.addWidget(self.dataset_dir)
dir_layout.addWidget(self.browse_button)
self.task_name = QLineEdit("aloha_mobile_dummy")
self.episode_idx = QSpinBox()
self.episode_idx.setRange(0, 1000)
self.max_timesteps = QSpinBox()
self.max_timesteps.setRange(1, 10000)
self.max_timesteps.setValue(500)
self.num_episodes = QSpinBox()
self.num_episodes.setRange(1, 100)
self.num_episodes.setValue(1)
self.frame_rate = QSpinBox()
self.frame_rate.setRange(1, 120)
self.frame_rate.setValue(30)
basic_form.addRow("数据集目录:", dir_layout)
basic_form.addRow("任务名称:", self.task_name)
basic_form.addRow("起始集索引:", self.episode_idx)
basic_form.addRow("最大时间步:", self.max_timesteps)
basic_form.addRow("集数:", self.num_episodes)
basic_form.addRow("帧率:", self.frame_rate)
basic_group.setLayout(basic_form)
config_layout.addWidget(basic_group)
# Camera topic group
camera_group = QGroupBox("相机话题")
camera_form = QFormLayout()
self.img_front_topic = QLineEdit('/camera_f/color/image_raw')
self.img_left_topic = QLineEdit('/camera_l/color/image_raw')
self.img_right_topic = QLineEdit('/camera_r/color/image_raw')
self.img_front_depth_topic = QLineEdit('/camera_f/depth/image_raw')
self.img_left_depth_topic = QLineEdit('/camera_l/depth/image_raw')
self.img_right_depth_topic = QLineEdit('/camera_r/depth/image_raw')
self.use_depth_image = QCheckBox("使用深度图像")
camera_form.addRow("前置相机:", self.img_front_topic)
camera_form.addRow("左腕相机:", self.img_left_topic)
camera_form.addRow("右腕相机:", self.img_right_topic)
camera_form.addRow("前置深度:", self.img_front_depth_topic)
camera_form.addRow("左腕深度:", self.img_left_depth_topic)
camera_form.addRow("右腕深度:", self.img_right_depth_topic)
camera_form.addRow("", self.use_depth_image)
camera_group.setLayout(camera_form)
config_layout.addWidget(camera_group)
# Robot topic group
robot_group = QGroupBox("机器人话题")
robot_form = QFormLayout()
self.master_arm_left_topic = QLineEdit('/master/joint_left')
self.master_arm_right_topic = QLineEdit('/master/joint_right')
self.puppet_arm_left_topic = QLineEdit('/puppet/joint_left')
self.puppet_arm_right_topic = QLineEdit('/puppet/joint_right')
self.robot_base_topic = QLineEdit('/odom')
self.use_robot_base = QCheckBox("使用机器人底盘")
robot_form.addRow("主左臂:", self.master_arm_left_topic)
robot_form.addRow("主右臂:", self.master_arm_right_topic)
robot_form.addRow("从左臂:", self.puppet_arm_left_topic)
robot_form.addRow("从右臂:", self.puppet_arm_right_topic)
robot_form.addRow("底盘:", self.robot_base_topic)
robot_form.addRow("", self.use_robot_base)
robot_group.setLayout(robot_form)
config_layout.addWidget(robot_group)
# Camera name configuration
camera_names_group = QGroupBox("相机名称")
camera_names_layout = QVBoxLayout()
self.camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
self.camera_checkboxes = {}
for cam_name in self.camera_names:
self.camera_checkboxes[cam_name] = QCheckBox(cam_name)
self.camera_checkboxes[cam_name].setChecked(True)
camera_names_layout.addWidget(self.camera_checkboxes[cam_name])
camera_names_group.setLayout(camera_names_layout)
config_layout.addWidget(camera_names_group)
# Save-configuration button
self.save_config_button = QPushButton("保存配置")
self.save_config_button.clicked.connect(self.save_config)
config_layout.addWidget(self.save_config_button)
def setup_collection_tab(self):
collection_layout = QVBoxLayout(self.collection_tab)
# Current configuration display
config_group = QGroupBox("当前配置")
self.config_text = QTextEdit()
self.config_text.setReadOnly(True)
config_layout = QVBoxLayout()
config_layout.addWidget(self.config_text)
config_group.setLayout(config_layout)
collection_layout.addWidget(config_group)
# Action buttons
buttons_layout = QHBoxLayout()
self.start_button = QPushButton("开始收集")
self.start_button.setIcon(QIcon.fromTheme("media-playback-start"))
self.start_button.clicked.connect(self.start_collection)
self.stop_button = QPushButton("停止")
self.stop_button.setIcon(QIcon.fromTheme("media-playback-stop"))
self.stop_button.setEnabled(False)
self.stop_button.clicked.connect(self.stop_collection)
buttons_layout.addWidget(self.start_button)
buttons_layout.addWidget(self.stop_button)
collection_layout.addLayout(buttons_layout)
# Progress bar
self.progress_bar = QProgressBar()
self.progress_bar.setValue(0)
collection_layout.addWidget(self.progress_bar)
# Log output
log_group = QGroupBox("操作日志")
self.log_text = QTextEdit()
self.log_text.setReadOnly(True)
log_layout = QVBoxLayout()
log_layout.addWidget(self.log_text)
log_group.setLayout(log_layout)
collection_layout.addWidget(log_group)
def browse_dataset_dir(self):
directory = QFileDialog.getExistingDirectory(self, "选择数据集目录", self.dataset_dir.text())
if directory:
self.dataset_dir.setText(directory)
def save_config(self):
# Update the configuration display
selected_cameras = [cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()]
config_text = f"""
任务名称: {self.task_name.text()}
数据集目录: {self.dataset_dir.text()}
起始集索引: {self.episode_idx.value()}
最大时间步: {self.max_timesteps.value()}
集数: {self.num_episodes.value()}
帧率: {self.frame_rate.value()}
使用深度图像: {"" if self.use_depth_image.isChecked() else ""}
使用机器人底盘: {"" if self.use_robot_base.isChecked() else ""}
相机: {', '.join(selected_cameras)}
"""
self.config_text.setText(config_text)
self.tab_widget.setCurrentIndex(1)  # switch to the collection tab
QMessageBox.information(self, "配置已保存", "配置已更新,可以开始数据收集")
def start_collection(self):
if not self.task_name.text():
QMessageBox.warning(self, "配置错误", "请输入有效的任务名称")
return
# Build the arguments
args = argparse.Namespace(
dataset_dir=self.dataset_dir.text(),
task_name=self.task_name.text(),
episode_idx=self.episode_idx.value(),
max_timesteps=self.max_timesteps.value(),
num_episodes=self.num_episodes.value(),
camera_names=[cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()],
img_front_topic=self.img_front_topic.text(),
img_left_topic=self.img_left_topic.text(),
img_right_topic=self.img_right_topic.text(),
img_front_depth_topic=self.img_front_depth_topic.text(),
img_left_depth_topic=self.img_left_depth_topic.text(),
img_right_depth_topic=self.img_right_depth_topic.text(),
master_arm_left_topic=self.master_arm_left_topic.text(),
master_arm_right_topic=self.master_arm_right_topic.text(),
puppet_arm_left_topic=self.puppet_arm_left_topic.text(),
puppet_arm_right_topic=self.puppet_arm_right_topic.text(),
robot_base_topic=self.robot_base_topic.text(),
use_robot_base=self.use_robot_base.isChecked(),
use_depth_image=self.use_depth_image.isChecked(),
frame_rate=self.frame_rate.value()
)
# Update the UI state
self.start_button.setEnabled(False)
self.stop_button.setEnabled(True)
self.progress_bar.setValue(0)
self.log_text.clear()
self.log_text.append("正在初始化数据收集...\n")
# Create and start the thread
self.collection_thread = DataCollectionThread(args)
self.collection_thread.update_signal.connect(self.update_log)
self.collection_thread.progress_signal.connect(self.update_progress)
self.collection_thread.finish_signal.connect(self.collection_finished)
self.collection_thread.start()
def stop_collection(self):
if self.collection_thread and self.collection_thread.isRunning():
self.log_text.append("正在停止数据收集...\n")
self.collection_thread.stop()
self.collection_thread = None
self.start_button.setEnabled(True)
self.stop_button.setEnabled(False)
@pyqtSlot(str)
def update_log(self, message):
self.log_text.append(message)
# Auto-scroll to the bottom
cursor = self.log_text.textCursor()
cursor.movePosition(QTextCursor.End)
self.log_text.setTextCursor(cursor)
@pyqtSlot(int)
def update_progress(self, value):
self.progress_bar.setValue(value)
@pyqtSlot(bool, str)
def collection_finished(self, success, message):
self.start_button.setEnabled(True)
self.stop_button.setEnabled(False)
if success:
QMessageBox.information(self, "完成", message)
else:
QMessageBox.warning(self, "出错", f"数据收集失败: {message}")
# Update the episode_idx value
if success and self.num_episodes.value() > 0:
self.episode_idx.setValue(self.episode_idx.value() + self.num_episodes.value())
def main():
app = QApplication(sys.argv)
window = AlohaDataCollectionGUI()
window.show()
sys.exit(app.exec_())
if __name__ == "__main__":
main()


@@ -1,462 +0,0 @@
import logging
import time
from dataclasses import asdict
from pprint import pformat
from pprint import pprint
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,
RecordControlConfig,
RemoteRobotConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.common.robot_devices.control_utils import (
# init_keyboard_listener,
record_episode,
stop_recording,
is_headless
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.common.utils.utils import get_safe_torch_device
from contextlib import nullcontext
from copy import copy
import torch
import rospy
import cv2
from lerobot.configs import parser
from agilex_robot import AgilexRobot
########################################################################################
# Control modes
########################################################################################
def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action
def control_loop(
robot,
control_time_s=None,
teleoperate=False,
display_cameras=False,
dataset: LeRobotDataset | None = None,
events=None,
policy = None,
fps: int | None = None,
single_task: str | None = None,
):
# TODO(rcadene): Add option to record logs
# if not robot.is_connected:
# robot.connect()
if events is None:
events = {"exit_early": False}
if control_time_s is None:
control_time_s = float("inf")
if dataset is not None and single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
start_episode_t = time.perf_counter()
rate = rospy.Rate(fps)
print_flag = True
while timestamp < control_time_s and not rospy.is_shutdown():
# print(timestamp < control_time_s)
# print(rospy.is_shutdown())
start_loop_t = time.perf_counter()
if teleoperate:
observation, action = robot.teleop_step()
if observation is None or action is None:
if print_flag:
print("sync data fail, retrying...\n")
print_flag = False
rate.sleep()
continue
else:
# pass
observation = robot.capture_observation()
if policy is not None:
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)
action = {"action": action}
if dataset is not None:
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
# if display_cameras and not is_headless():
# image_keys = [key for key in observation if "image" in key]
# for key in image_keys:
# if "depth" in key:
# pass
# else:
# cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
# print(1)
# cv2.waitKey(1)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
# Get the screen resolution (assumed to be 1920x1080; adjust to your setup)
screen_width = 1920
screen_height = 1080
# Compute the window layout
num_images = len(image_keys)
max_columns = int(screen_width / 640)  # assume each window is 640 px wide
rows = (num_images + max_columns - 1) // max_columns  # number of rows needed
columns = min(num_images, max_columns)  # number of columns actually used
# Display every image key
for idx, key in enumerate(image_keys):
if "depth" in key:
continue  # skip depth images
# Convert the image from RGB to BGR
image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
# Create/update the window
cv2.imshow(key, image)
# Compute the window position
window_width = 640
window_height = 480
row = idx // max_columns
col = idx % max_columns
x_position = col * window_width
y_position = row * window_height
# Move the window to its slot
cv2.moveWindow(key, x_position, y_position)
# Wait 1 ms so GUI events are processed
cv2.waitKey(1)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
# log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_episode_t
if events["exit_early"]:
events["exit_early"] = False
break
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["record_start"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
events["record_start"] = False
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
elif key == keyboard.Key.up:
print("Up arrow pressed. Start data recording...")
events["record_start"] = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def stop_recording(robot, listener, display_cameras):
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
def record_episode(
robot,
dataset,
events,
episode_time_s,
display_cameras,
policy,
fps,
single_task,
):
control_loop(
robot=robot,
control_time_s=episode_time_s,
display_cameras=display_cameras,
dataset=dataset,
events=events,
policy=policy,
fps=fps,
teleoperate=policy is None,
single_task=single_task,
)
def record(
robot,
cfg
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
if cfg.resume:
dataset = LeRobotDataset(
cfg.repo_id,
root=cfg.root,
)
if len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.num_image_writer_processes,
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
# sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
else:
# Create empty dataset or load existing saved episodes
# sanity_check_dataset_name(cfg.repo_id, cfg.policy)
dataset = LeRobotDataset.create(
cfg.repo_id,
cfg.fps,
root=cfg.root,
robot=None,
features=robot.features,
use_videos=cfg.video,
image_writer_processes=cfg.num_image_writer_processes,
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
# policy = None
# if not robot.is_connected:
# robot.connect()
listener, events = init_keyboard_listener()
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,
# 2. give times to the robot devices to connect and start synchronizing,
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds)
print()
print(f"开始记录轨迹,共需要记录{cfg.num_episodes}\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")
# warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
# if has_method(robot, "teleop_safety_stop"):
# robot.teleop_safety_stop()
recorded_episodes = 0
while True:
if recorded_episodes >= cfg.num_episodes:
break
# if events["record_start"]:
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
record_episode(
robot=robot,
dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
fps=cfg.fps,
single_task=cfg.single_task,
)
# Execute a few seconds without recording to give time to manually reset the environment
# Current code logic doesn't allow to teleoperate during this time.
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
pprint("Reset the environment, stop recording")
# reset_environment(robot, events, cfg.reset_time_s, cfg.fps)
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
pprint("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded_episodes += 1
if events["stop_recording"]:
break
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
log_say("Exiting", cfg.play_sounds)
return dataset
def replay(
robot: AgilexRobot,
cfg,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action")
# if not robot.is_connected:
# robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / cfg.fps - dt_s)
dt_s = time.perf_counter() - start_episode_t
# log_control_info(robot, dt_s, fps=cfg.fps)
import argparse
def get_arguments():
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.fps = 30
args.resume = False
args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right"
args.root = "./data5"
args.episode = 0 # replay episode
args.num_image_writer_processes = 0
args.num_image_writer_threads_per_camera = 4
args.video = True
args.num_episodes = 100
args.episode_time_s = 30000
args.play_sounds = False
args.display_cameras = True
args.single_task = "move the bottle from the right to the scale right"
args.use_depth_image = False
args.use_base = False
args.push_to_hub = False
args.policy = None
# args.teleoprate = True
args.control_type = "record"
# args.control_type = "replay"
return args
# @parser.wrap()
def control_robot(cfg):
from rosrobot_factory import RobotFactory
# Create the robot instance via the factory pattern
robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/collect_data/agilex.yaml", args=cfg)
if cfg.control_type == "record":
record(robot, cfg)
elif cfg.control_type == "replay":
replay(robot, cfg)
if __name__ == "__main__":
cfg = get_arguments()
control_robot(cfg)
# control_robot()
# Create the robot instance via the factory pattern
# robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# print(robot.features.items())
# print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
# record(robot, cfg)
# capture = robot.capture_observation()
# import torch
# torch.save(capture, "test.pt")
# action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094,
# 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]],
# device='cpu')
# robot.send_action(action.squeeze(0))
# print()


@@ -1,3 +0,0 @@
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5
# export LD_LIBRARY_PATH=/home/ubuntu/miniconda3/envs/lerobot/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH


@@ -1,769 +0,0 @@
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -*- coding: UTF-8 -*-
"""
#!/usr/bin/python3
"""
import torch
import numpy as np
import os
import pickle
import argparse
from einops import rearrange
import collections
from collections import deque
import rospy
from std_msgs.msg import Header
from geometry_msgs.msg import Twist
from sensor_msgs.msg import JointState, Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
import time
import threading
import math
import sys
sys.path.append("./")
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']}
inference_thread = None
inference_lock = threading.Lock()
inference_actions = None
inference_timestep = None
def actions_interpolation(args, pre_action, actions, stats):
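# Densify the predicted action chunk: starting from the largest jump relative to the previous action (gripper joints 6 and 13 excluded from the distance), insert linear interpolation points sized by arm_steps_length, then re-normalize.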
steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
result = [pre_action]
post_action = post_process(actions[0])
# print("pre_action:", pre_action[7:])
# print("actions_interpolation1:", post_action[:, 7:])
max_diff_index = 0
max_diff = -1
for i in range(post_action.shape[0]):
diff = 0
for j in range(pre_action.shape[0]):
if j == 6 or j == 13:
continue
diff += math.fabs(pre_action[j] - post_action[i][j])
if diff > max_diff:
max_diff = diff
max_diff_index = i
for i in range(max_diff_index, post_action.shape[0]):
step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])])
inter = np.linspace(result[-1], post_action[i], step+2)
result.extend(inter[1:])
while len(result) < args.chunk_size+1:
result.append(result[-1])
result = np.array(result)[1:args.chunk_size+1]
# print("actions_interpolation2:", result.shape, result[:, 7:])
result = pre_process(result)
result = result[np.newaxis, :]
return result
def get_model_config(args):
# Fix the random seed so the same random sequence is generated on every run.
set_seed(1)
# ACT policy
# fixed parameters
if args.policy_class == 'ACT':
policy_config = {'lr': args.lr,
'lr_backbone': args.lr_backbone,
'backbone': args.backbone,
'masks': args.masks,
'weight_decay': args.weight_decay,
'dilation': args.dilation,
'position_embedding': args.position_embedding,
'loss_function': args.loss_function,
'chunk_size': args.chunk_size,  # query chunk size
'camera_names': task_config['camera_names'],
'use_depth_image': args.use_depth_image,
'use_robot_base': args.use_robot_base,
'kl_weight': args.kl_weight,  # KL-divergence weight
'hidden_dim': args.hidden_dim,  # hidden-layer dimension
'dim_feedforward': args.dim_feedforward,
'enc_layers': args.enc_layers,
'dec_layers': args.dec_layers,
'nheads': args.nheads,
'dropout': args.dropout,
'pre_norm': args.pre_norm
}
elif args.policy_class == 'CNNMLP':
policy_config = {'lr': args.lr,
'lr_backbone': args.lr_backbone,
'backbone': args.backbone,
'masks': args.masks,
'weight_decay': args.weight_decay,
'dilation': args.dilation,
'position_embedding': args.position_embedding,
'loss_function': args.loss_function,
'chunk_size': 1,  # query chunk size
'camera_names': task_config['camera_names'],
'use_depth_image': args.use_depth_image,
'use_robot_base': args.use_robot_base
}
elif args.policy_class == 'Diffusion':
policy_config = {'lr': args.lr,
'lr_backbone': args.lr_backbone,
'backbone': args.backbone,
'masks': args.masks,
'weight_decay': args.weight_decay,
'dilation': args.dilation,
'position_embedding': args.position_embedding,
'loss_function': args.loss_function,
'chunk_size': args.chunk_size,  # query chunk size
'camera_names': task_config['camera_names'],
'use_depth_image': args.use_depth_image,
'use_robot_base': args.use_robot_base,
'observation_horizon': args.observation_horizon,
'action_horizon': args.action_horizon,
'num_inference_timesteps': args.num_inference_timesteps,
'ema_power': args.ema_power
}
else:
raise NotImplementedError
config = {
'ckpt_dir': args.ckpt_dir,
'ckpt_name': args.ckpt_name,
'ckpt_stats_name': args.ckpt_stats_name,
'episode_len': args.max_publish_step,
'state_dim': args.state_dim,
'policy_class': args.policy_class,
'policy_config': policy_config,
'temporal_agg': args.temporal_agg,
'camera_names': task_config['camera_names'],
}
return config
def make_policy(policy_class, policy_config):
if policy_class == 'ACT':
policy = ACTPolicy(policy_config)
elif policy_class == 'CNNMLP':
policy = CNNMLPPolicy(policy_config)
elif policy_class == 'Diffusion':
policy = DiffusionPolicy(policy_config)
else:
raise NotImplementedError
return policy
def get_image(observation, camera_names):
curr_images = []
for cam_name in camera_names:
curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w')
curr_images.append(curr_image)
curr_image = np.stack(curr_images, axis=0)
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
return curr_image
def get_depth_image(observation, camera_names):
curr_images = []
for cam_name in camera_names:
curr_images.append(observation['images_depth'][cam_name])
curr_image = np.stack(curr_images, axis=0)
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
return curr_image
def inference_process(args, config, ros_operator, policy, stats, t, pre_action):
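# Worker-thread body: wait for a synchronized frame, build the observation, normalize qpos, run the policy once and store the resulting (optionally interpolated) action chunk in inference_actions under inference_lock.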
global inference_lock
global inference_actions
global inference_timestep
print_flag = True
pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"]
rate = rospy.Rate(args.publish_rate)
while True and not rospy.is_shutdown():
result = ros_operator.get_frame()
if not result:
if print_flag:
print("syn fail")
print_flag = False
rate.sleep()
continue
print_flag = True
(img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base) = result
obs = collections.OrderedDict()
image_dict = dict()
image_dict[config['camera_names'][0]] = img_front
image_dict[config['camera_names'][1]] = img_left
image_dict[config['camera_names'][2]] = img_right
obs['images'] = image_dict
if args.use_depth_image:
image_depth_dict = dict()
image_depth_dict[config['camera_names'][0]] = img_front_depth
image_depth_dict[config['camera_names'][1]] = img_left_depth
image_depth_dict[config['camera_names'][2]] = img_right_depth
obs['images_depth'] = image_depth_dict
obs['qpos'] = np.concatenate(
(np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
obs['qvel'] = np.concatenate(
(np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
obs['effort'] = np.concatenate(
(np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
if args.use_robot_base:
obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0)
else:
obs['base_vel'] = [0.0, 0.0]
# qpos_numpy = np.array(obs['qpos'])
# Normalize qpos and move it to CUDA
qpos = pre_pos_process(obs['qpos'])
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
# Get the current images (curr_image)
curr_image = get_image(obs, config['camera_names'])
curr_depth_image = None
if args.use_depth_image:
curr_depth_image = get_depth_image(obs, config['camera_names'])
start_time = time.time()
all_actions = policy(curr_image, curr_depth_image, qpos)
end_time = time.time()
print("model cost time: ", end_time -start_time)
inference_lock.acquire()
inference_actions = all_actions.cpu().detach().numpy()
if pre_action is None:
pre_action = obs['qpos']
# print("obs['qpos']:", obs['qpos'][7:])
if args.use_actions_interpolation:
inference_actions = actions_interpolation(args, pre_action, inference_actions, stats)
inference_timestep = t
inference_lock.release()
break
def model_inference(args, config, ros_operator, save_episode=True):
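# Inference entry point: build the policy, load the checkpoint (dropping is_pad_head / input_proj_next_action keys), load the dataset statistics, move the arms to the start pose, then step the policy at args.publish_rate and publish joint commands.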
global inference_lock
global inference_actions
global inference_timestep
global inference_thread
set_seed(1000)
# 1. Build the policy model (an nn.Module)
policy = make_policy(config['policy_class'], config['policy_config'])
# print("model structure\n", policy.model)
# 2. Load the model weights
ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name'])
state_dict = torch.load(ckpt_path)
new_state_dict = {}
for key, value in state_dict.items():
if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]:
continue
if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]:
continue
new_state_dict[key] = value
loading_status = policy.deserialize(new_state_dict)
if not loading_status:
print("ckpt path not exist")
return False
# 3. Move the model to CUDA and switch to eval mode
policy.cuda()
policy.eval()
# 4. Load the normalization statistics
stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name'])
# Statistics: action_mean, action_std, qpos_mean, qpos_std (14-dim)
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
# Pre- and post-processing functions
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
max_publish_step = config['episode_len']
chunk_size = config['policy_config']['chunk_size']
# Publish the initial (home) poses
left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
ros_operator.puppet_arm_publish_continuous(left0, right0)
input("Enter any key to continue :")
ros_operator.puppet_arm_publish_continuous(left1, right1)
action = None
# Inference
with torch.inference_mode():
while True and not rospy.is_shutdown():
# Step counter within the episode
t = 0
max_t = 0
rate = rospy.Rate(args.publish_rate)
if config['temporal_agg']:
all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']])
while t < max_publish_step and not rospy.is_shutdown():
# start_time = time.time()
# query policy
if config['policy_class'] == "ACT":
if t >= max_t:
pre_action = action
inference_thread = threading.Thread(target=inference_process,
args=(args, config, ros_operator,
policy, stats, t, pre_action))
inference_thread.start()
inference_thread.join()
inference_lock.acquire()
if inference_actions is not None:
inference_thread = None
all_actions = inference_actions
inference_actions = None
max_t = t + args.pos_lookahead_step
if config['temporal_agg']:
all_time_actions[[t], t:t + chunk_size] = all_actions
inference_lock.release()
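# Temporal ensembling: all_time_actions[s, t] stores the action for timestep t
# predicted by the chunk inferred at step s. The action executed at the current
# step is an exponentially weighted average over all available predictions for
# that step, with weights exp(-k * i) where i runs from the oldest prediction to
# the newest, so earlier predictions get slightly larger weight (k = 0.01).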
if config['temporal_agg']:
actions_for_curr_step = all_time_actions[:, t]
actions_populated = np.all(actions_for_curr_step != 0, axis=1)
actions_for_curr_step = actions_for_curr_step[actions_populated]
k = 0.01
exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
exp_weights = exp_weights / exp_weights.sum()
exp_weights = exp_weights[:, np.newaxis]
raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True)
else:
if args.pos_lookahead_step != 0:
raw_action = all_actions[:, t % args.pos_lookahead_step]
else:
raw_action = all_actions[:, t % chunk_size]
else:
raise NotImplementedError
action = post_process(raw_action[0])
left_action = action[:7]  # first 7 dims: left arm
right_action = action[7:14]
ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread
if args.use_robot_base:
vel_action = action[14:16]
ros_operator.robot_base_publish(vel_action)
t += 1
# end_time = time.time()
# print("publish: ", t)
# print("time:", end_time - start_time)
# print("left_action:", left_action)
# print("right_action:", right_action)
rate.sleep()
class RosOperator:
def __init__(self, args):
self.robot_base_deque = None
self.puppet_arm_right_deque = None
self.puppet_arm_left_deque = None
self.img_front_deque = None
self.img_right_deque = None
self.img_left_deque = None
self.img_front_depth_deque = None
self.img_right_depth_deque = None
self.img_left_depth_deque = None
self.bridge = None
self.puppet_arm_left_publisher = None
self.puppet_arm_right_publisher = None
self.robot_base_publisher = None
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_lock = None
self.args = args
self.ctrl_state = False
self.ctrl_state_lock = threading.Lock()
self.init()
self.init_ros()
def init(self):
self.bridge = CvBridge()
self.img_left_deque = deque()
self.img_right_deque = deque()
self.img_front_deque = deque()
self.img_left_depth_deque = deque()
self.img_right_depth_deque = deque()
self.img_front_depth_deque = deque()
self.puppet_arm_left_deque = deque()
self.puppet_arm_right_deque = deque()
self.robot_base_deque = deque()
self.puppet_arm_publish_lock = threading.Lock()
self.puppet_arm_publish_lock.acquire()
def puppet_arm_publish(self, left, right):
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now()  # timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
joint_state_msg.position = left
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right
self.puppet_arm_right_publisher.publish(joint_state_msg)
def robot_base_publish(self, vel):
vel_msg = Twist()
vel_msg.linear.x = vel[0]
vel_msg.linear.y = 0
vel_msg.linear.z = 0
vel_msg.angular.x = 0
vel_msg.angular.y = 0
vel_msg.angular.z = vel[1]
self.robot_base_publisher.publish(vel_msg)
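# puppet_arm_publish_continuous ramps both arms from their current measured
# joint positions toward the target positions in fixed increments
# (args.arm_steps_length per joint per cycle), publishing at args.publish_rate
# until every joint reaches its target or the publish lock is released by
# another thread (see puppet_arm_publish_continuous_thread).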
def puppet_arm_publish_continuous(self, left, right):
rate = rospy.Rate(self.args.publish_rate)
left_arm = None
right_arm = None
while not rospy.is_shutdown():
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
break
left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
flag = True
step = 0
while flag and not rospy.is_shutdown():
if self.puppet_arm_publish_lock.acquire(False):
return
left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
flag = False
for i in range(len(left)):
if left_diff[i] < self.args.arm_steps_length[i]:
left_arm[i] = left[i]
else:
left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
flag = True
for i in range(len(right)):
if right_diff[i] < self.args.arm_steps_length[i]:
right_arm[i] = right[i]
else:
right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
flag = True
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now()  # timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
joint_state_msg.position = left_arm
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right_arm
self.puppet_arm_right_publisher.publish(joint_state_msg)
step += 1
print("puppet_arm_publish_continuous:", step)
rate.sleep()
def puppet_arm_publish_linear(self, left, right):
num_step = 100
rate = rospy.Rate(200)
left_arm = None
right_arm = None
while not rospy.is_shutdown():
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
break
traj_left_list = np.linspace(left_arm, left, num_step)
traj_right_list = np.linspace(right_arm, right, num_step)
for i in range(len(traj_left_list)):
traj_left = traj_left_list[i]
traj_right = traj_right_list[i]
traj_left[-1] = left[-1]
traj_right[-1] = right[-1]
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now()  # timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
joint_state_msg.position = traj_left
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = traj_right
self.puppet_arm_right_publisher.publish(joint_state_msg)
rate.sleep()
def puppet_arm_publish_continuous_thread(self, left, right):
if self.puppet_arm_publish_thread is not None:
self.puppet_arm_publish_lock.release()
self.puppet_arm_publish_thread.join()
self.puppet_arm_publish_lock.acquire(False)
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
self.puppet_arm_publish_thread.start()
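# get_frame collects one synchronized frame: it takes the latest timestamp that
# all RGB (and optionally depth) streams have reached, checks that the arm and
# base deques have caught up to it, drops older messages from each deque, and
# returns the aligned images and joint states, or False if any stream lags.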
def get_frame(self):
if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
return False
if self.args.use_depth_image:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
else:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
return False
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_deque.popleft()
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_deque.popleft()
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_deque.popleft()
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_left_deque.popleft()
puppet_arm_left = self.puppet_arm_left_deque.popleft()
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_right_deque.popleft()
puppet_arm_right = self.puppet_arm_right_deque.popleft()
img_left_depth = None
if self.args.use_depth_image:
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_depth_deque.popleft()
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
img_right_depth = None
if self.args.use_depth_image:
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_depth_deque.popleft()
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
img_front_depth = None
if self.args.use_depth_image:
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_depth_deque.popleft()
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
robot_base = None
if self.args.use_robot_base:
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
self.robot_base_deque.popleft()
robot_base = self.robot_base_deque.popleft()
return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base)
def img_left_callback(self, msg):
if len(self.img_left_deque) >= 2000:
self.img_left_deque.popleft()
self.img_left_deque.append(msg)
def img_right_callback(self, msg):
if len(self.img_right_deque) >= 2000:
self.img_right_deque.popleft()
self.img_right_deque.append(msg)
def img_front_callback(self, msg):
if len(self.img_front_deque) >= 2000:
self.img_front_deque.popleft()
self.img_front_deque.append(msg)
def img_left_depth_callback(self, msg):
if len(self.img_left_depth_deque) >= 2000:
self.img_left_depth_deque.popleft()
self.img_left_depth_deque.append(msg)
def img_right_depth_callback(self, msg):
if len(self.img_right_depth_deque) >= 2000:
self.img_right_depth_deque.popleft()
self.img_right_depth_deque.append(msg)
def img_front_depth_callback(self, msg):
if len(self.img_front_depth_deque) >= 2000:
self.img_front_depth_deque.popleft()
self.img_front_depth_deque.append(msg)
def puppet_arm_left_callback(self, msg):
if len(self.puppet_arm_left_deque) >= 2000:
self.puppet_arm_left_deque.popleft()
self.puppet_arm_left_deque.append(msg)
def puppet_arm_right_callback(self, msg):
if len(self.puppet_arm_right_deque) >= 2000:
self.puppet_arm_right_deque.popleft()
self.puppet_arm_right_deque.append(msg)
def robot_base_callback(self, msg):
if len(self.robot_base_deque) >= 2000:
self.robot_base_deque.popleft()
self.robot_base_deque.append(msg)
def ctrl_callback(self, msg):
self.ctrl_state_lock.acquire()
self.ctrl_state = msg.data
self.ctrl_state_lock.release()
def get_ctrl_state(self):
self.ctrl_state_lock.acquire()
state = self.ctrl_state
self.ctrl_state_lock.release()
return state
def init_ros(self):
rospy.init_node('joint_state_publisher', anonymous=True)
rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
if self.args.use_depth_image:
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False)
parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False)
parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False)
parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False)
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False)
parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False)
parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False)
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False)
parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False)
parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False)
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False)
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features", required=False)
parser.add_argument('--masks', action='store_true',
help="Train segmentation head if the flag is provided")
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False)
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False)
parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False)
parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False)
parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False)
parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False)
parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False)
parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False)
parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False)
parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False)
parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False)
parser.add_argument('--pre_norm', action='store_true', required=False)
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
default='/camera_f/color/image_raw', required=False)
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
default='/camera_l/color/image_raw', required=False)
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
default='/camera_r/color/image_raw', required=False)
parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
default='/camera_f/depth/image_raw', required=False)
parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
default='/camera_l/depth/image_raw', required=False)
parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
default='/camera_r/depth/image_raw', required=False)
parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
default='/master/joint_left', required=False)
parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
default='/master/joint_right', required=False)
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
default='/puppet/joint_left', required=False)
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
default='/puppet/joint_right', required=False)
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
default='/odom_raw', required=False)
parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_cmd_topic',
default='/cmd_vel', required=False)
parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
default=False, required=False)
parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate',
default=40, required=False)
parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step',
default=0, required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size',
default=32, required=False)
parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length',
default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False)
parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation',
default=False, required=False)
parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image',
default=False, required=False)
# for Diffusion
parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False)
parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False)
parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False)
parser.add_argument('--ema_power', action='store', type=int, help='ema_power', default=0.75, required=False)
args = parser.parse_args()
return args
def main():
args = get_arguments()
ros_operator = RosOperator(args)
config = get_model_config(args)
model_inference(args, config, ros_operator, save_episode=True)
if __name__ == '__main__':
main()
# python act/inference.py --ckpt_dir ~/train0314/
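# A fuller example invocation using flags defined in get_arguments above (values are illustrative):
# python act/inference.py --ckpt_dir ~/train0314/ --ckpt_name policy_best.ckpt --publish_rate 40 --chunk_size 32 --pos_lookahead_step 0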

View File

@@ -1,33 +0,0 @@
import pandas as pd
def read_and_print_parquet_row(file_path, row_index=0):
"""
读取Parquet文件并打印指定行的数据
参数:
file_path (str): Parquet文件路径
row_index (int): 要打印的行索引默认为第0行
"""
try:
# 读取Parquet文件
df = pd.read_parquet(file_path)
# 检查行索引是否有效
if row_index >= len(df):
print(f"错误: 行索引 {row_index} 超出范围(文件共有 {len(df)} 行)")
return
# 打印指定行数据
print(f"文件: {file_path}")
print(f"{row_index} 行数据:\n{'-'*30}")
print(df.iloc[row_index])
except FileNotFoundError:
print(f"错误: 文件 {file_path} 不存在")
except Exception as e:
print(f"读取失败: {str(e)}")
# 示例用法
if __name__ == "__main__":
file_path = "example.parquet" # 替换为你的Parquet文件路径
read_and_print_parquet_row("/home/jgl20/LYT/work/data/data/chunk-000/episode_000000.parquet", row_index=0) # 打印第0行

207
collect_data/replay_data.py Normal file → Executable file
View File

@@ -10,24 +10,63 @@ from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from geometry_msgs.msg import Twist
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import sys
sys.path.append("./")
def load_hdf5(dataset_dir, dataset_name):
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
if not os.path.isfile(dataset_path):
print(f'Dataset does not exist at \n{dataset_path}\n')
exit()
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
compressed = root.attrs.get('compress', False)
qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()]
if 'effort' in root.keys():
effort = root['/observations/effort'][()]
else:
effort = None
action = root['/action'][()]
base_action = root['/base_action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
if compressed:
compress_len = root['/compress_len'][()]
if compressed:
for cam_id, cam_name in enumerate(image_dict.keys()):
# un-pad and uncompress
padded_compressed_image_list = image_dict[cam_name]
image_list = []
for frame_id, padded_compressed_image in enumerate(padded_compressed_image_list): # [:1000] to save memory
image_len = int(compress_len[cam_id, frame_id])
compressed_image = padded_compressed_image
image = cv2.imdecode(compressed_image, 1)
image_list.append(image)
image_dict[cam_name] = image_list
return qpos, qvel, effort, action, base_action, image_dict
def main(args):
rospy.init_node("replay_node")
bridge = CvBridge()
# img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
# img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
# img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
# puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
# puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)
# robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
# dataset_dir = args.dataset_dir
@@ -35,78 +74,130 @@ def main(args):
# task_name = args.task_name
# dataset_name = f'episode_{episode_idx}'
dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
actions = dataset.hf_dataset.select_columns("action")
velocitys = dataset.hf_dataset.select_columns("observation.velocity")
efforts = dataset.hf_dataset.select_columns("observation.effort")
origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279]
origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']  # joint names
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
twist_msg = Twist()
rate = rospy.Rate(args.fps)
rate = rospy.Rate(args.frame_rate)
# qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)
actions = [[ 0.1277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.3238, -0.1094,
0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]]
last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
rate = rospy.Rate(50)
for idx in range(len(actions)):
action = actions[idx]['action'].detach().cpu().numpy()
velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy()
effort = efforts[idx]['observation.effort'].detach().cpu().numpy()
if(rospy.is_shutdown()):
break
new_actions = np.linspace(last_action, action, 5)  # interpolate
new_velocitys = np.linspace(last_velocity, velocity, 5)  # interpolate
new_efforts = np.linspace(last_effort, effort, 5)  # interpolate
last_action = action
last_velocity = velocity
last_effort = effort
for act in new_actions:
print(np.round(act[:7], 4))
cur_timestamp = rospy.Time.now()  # timestamp
joint_state_msg.header.stamp = cur_timestamp
if not args.only_pub_master:
last_action = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279, 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
rate = rospy.Rate(100)
for action in actions:
if(rospy.is_shutdown()):
break
joint_state_msg.position = act[:7]
joint_state_msg.velocity = last_velocity[:7]
joint_state_msg.effort = last_effort[:7]
new_actions = np.linspace(last_action, action, 50)  # interpolate
last_action = action
for act in new_actions:
print(np.round(act[:7], 4))
cur_timestamp = rospy.Time.now()  # timestamp
joint_state_msg.header.stamp = cur_timestamp
joint_state_msg.position = act[:7]
master_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = act[7:]
master_arm_right_publisher.publish(joint_state_msg)
if(rospy.is_shutdown()):
break
rate.sleep()
else:
i = 0
while(not rospy.is_shutdown() and i < len(actions)):
print("left: ", np.round(qposs[i][:7], 4), " right: ", np.round(qposs[i][7:], 4))
cam_names = [k for k in image_dicts.keys()]
image0 = image_dicts[cam_names[0]][i]
image0 = image0[:, :, [2, 1, 0]] # swap B and R channel
image1 = image_dicts[cam_names[1]][i]
image1 = image1[:, :, [2, 1, 0]] # swap B and R channel
image2 = image_dicts[cam_names[2]][i]
image2 = image2[:, :, [2, 1, 0]] # swap B and R channel
cur_timestamp = rospy.Time.now()  # timestamp
joint_state_msg.header.stamp = cur_timestamp
joint_state_msg.position = actions[i][:7]
master_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = act[7:]
joint_state_msg.velocity = last_velocity[:7]
joint_state_msg.effort = last_effort[7:]
joint_state_msg.position = actions[i][7:]
master_arm_right_publisher.publish(joint_state_msg)
if(rospy.is_shutdown()):
break
rate.sleep()
joint_state_msg.position = qposs[i][:7]
puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = qposs[i][7:]
puppet_arm_right_publisher.publish(joint_state_msg)
img_front_publisher.publish(bridge.cv2_to_imgmsg(image0, "bgr8"))
img_left_publisher.publish(bridge.cv2_to_imgmsg(image1, "bgr8"))
img_right_publisher.publish(bridge.cv2_to_imgmsg(image2, "bgr8"))
twist_msg.linear.x = base_actions[i][0]
twist_msg.angular.z = base_actions[i][1]
robot_base_publisher.publish(twist_msg)
i += 1
rate.sleep()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
# default='/master/joint_left', required=False)
# parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
# default='/master/joint_right', required=False)
args = parser.parse_args()
args.repo_id = "tangger/test"
args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
args.episode = 1 # replay episode
args.master_arm_left_topic = "/master/joint_left"
args.master_arm_right_topic = "/master/joint_right"
args.fps = 30
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=False)
parser.add_argument('--task_name', action='store', type=str, help='Task name.',
default="aloha_mobile_dummy", required=False)
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.',default=0, required=False)
parser.add_argument('--camera_names', action='store', type=str, help='camera_names',
default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False)
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
default='/camera_f/color/image_raw', required=False)
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
default='/camera_l/color/image_raw', required=False)
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
default='/camera_r/color/image_raw', required=False)
parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
default='/master/joint_left', required=False)
parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
default='/master/joint_right', required=False)
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
default='/puppet/joint_left', required=False)
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
default='/puppet/joint_right', required=False)
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
default='/cmd_vel', required=False)
parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
default=False, required=False)
parser.add_argument('--frame_rate', action='store', type=int, help='frame_rate',
default=30, required=False)
parser.add_argument('--only_pub_master', action='store_true', help='only_pub_master',required=False)
args = parser.parse_args()
main(args)
# python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0

10
collect_data/requirements.txt Executable file
View File

@@ -0,0 +1,10 @@
opencv-python==4.9.0.80
matplotlib==3.7.5
h5py==3.8.0
dm-env==1.6
numpy==1.23.4
pyyaml
rospkg==1.5.0
catkin-pkg==1.0.0
empy==3.3.4
PyQt5==5.15.10

View File

@@ -1,70 +0,0 @@
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.robot_devices.utils import busy_wait
import time
import argparse
from agilex_robot import AgilexRobot
import torch
def get_arguments():
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.fps = 30
args.resume = False
args.repo_id = "tangger/test"
args.root = "./data2"
args.num_image_writer_processes = 0
args.num_image_writer_threads_per_camera = 4
args.video = True
args.num_episodes = 50
args.episode_time_s = 30000
args.play_sounds = False
args.display_cameras = True
args.single_task = "test test"
args.use_depth_image = False
args.use_base = False
args.push_to_hub = False
args.policy = None
args.teleoprate = False
return args
cfg = get_arguments()
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
inference_time_s = 360
fps = 30
device = "cuda" # TODO: On Mac, use "mps" or "cpu"
ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"
policy = ACTPolicy.from_pretrained(ckpt_path)
policy.to(device)
for _ in range(inference_time_s * fps):
start_time = time.perf_counter()
# Read the follower state and access the frames from the cameras
observation = robot.capture_observation()
if observation is None:
print("Observation is None, skipping...")
continue
# Convert to pytorch format: channel first and float32 in [0,1]
# with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
# Order the robot to move
robot.send_action(action)
dt_s = time.perf_counter() - start_time
busy_wait(1 / fps - dt_s)

272
collect_data/utils.py Normal file
View File

@@ -0,0 +1,272 @@
import cv2
import numpy as np
import h5py
import time
def display_camera_grid(image_dict, grid_shape=None, window_name="MindRobot-V1 Data Collection", scale=1.0):
"""
显示多摄像头画面(保持原始比例,但可整体缩放)
参数:
image_dict: {摄像头名称: 图像numpy数组}
grid_shape: (行, 列) 布局None自动计算
window_name: 窗口名称
scale: 整体显示缩放比例0.5表示显示为原尺寸的50%
"""
# 输入验证和数据处理(保持原代码不变)
if not isinstance(image_dict, dict):
raise TypeError("输入必须是字典类型")
valid_data = []
for name, img in image_dict.items():
if not isinstance(img, np.ndarray):
continue
if img.dtype != np.uint8:
img = img.astype(np.uint8)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
elif img.shape[2] == 4:
img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
elif img.shape[2] == 3:
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
valid_data.append((name, img))
if not valid_data:
print("错误: 没有有效的图像可显示!")
return None
# 自动计算网格布局
num_valid = len(valid_data)
if grid_shape is None:
grid_shape = (1, num_valid) if num_valid <= 3 else (2, int(np.ceil(num_valid/2)))
rows, cols = grid_shape
# Compute the maximum height of each row and width of each column
row_heights = [0]*rows
col_widths = [0]*cols
for i, (_, img) in enumerate(valid_data[:rows*cols]):
r, c = i//cols, i%cols
row_heights[r] = max(row_heights[r], img.shape[0])
col_widths[c] = max(col_widths[c], img.shape[1])
# Total canvas size (with the overall scale applied)
canvas_h = int(sum(row_heights) * scale)
canvas_w = int(sum(col_widths) * scale)
# Create the canvas
canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)
# Display region of every sub-view
row_pos = [0] + [int(sum(row_heights[:i+1])*scale) for i in range(rows)]
col_pos = [0] + [int(sum(col_widths[:i+1])*scale) for i in range(cols)]
# Place the images
for i, (name, img) in enumerate(valid_data[:rows*cols]):
r, c = i//cols, i%cols
# Display region of the current image
x1, x2 = col_pos[c], col_pos[c+1]
y1, y2 = row_pos[r], row_pos[r+1]
# Scaled size of the current image
display_h = int(img.shape[0] * scale)
display_w = int(img.shape[1] * scale)
# Resize the image (aspect ratio preserved)
resized_img = cv2.resize(img, (display_w, display_h))
# Place it onto the canvas
canvas[y1:y1+display_h, x1:x1+display_w] = resized_img
# Add the camera label (font scaled proportionally)
font_scale = 0.8 * scale
thickness = max(2, int(2 * scale))
cv2.putText(canvas, name, (x1+10, y1+30),
cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255,255,255), thickness)
# Show the window (resizes to fit the screen)
cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
cv2.imshow(window_name, canvas)
cv2.resizeWindow(window_name, canvas_w, canvas_h)
cv2.waitKey(1)
return canvas
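# Example usage (a minimal sketch; camera names and image shapes are illustrative):
# grid = display_camera_grid(
#     {"cam_high": np.zeros((480, 640, 3), dtype=np.uint8),
#      "cam_left": np.zeros((480, 640, 3), dtype=np.uint8)},
#     scale=0.5)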
# Data saving function
def save_data(args, timesteps, actions, dataset_path):
# Data dictionary
data_size = len(actions)
data_dict = {
# Observation streams (qpos, qvel, effort) plus the action that was actually sent
'/observations/qpos': [],
'/observations/qvel': [],
'/observations/effort': [],
'/action': [],
'/base_action': [],
# '/base_action_t265': [],
}
# Camera entries: one list of observed images per camera
for cam_name in args.camera_names:
data_dict[f'/observations/images/{cam_name}'] = []
if args.use_depth_image:
data_dict[f'/observations/images_depth/{cam_name}'] = []
# len(action): max_timesteps, len(time_steps): max_timesteps + 1
# Iterate over the recorded actions
while actions:
# Pop one element from each queue
action = actions.pop(0)  # the action sent at this step
ts = timesteps.pop(0)  # the observation from the previous frame
# Fill the dictionary
# The TimeStep carries qpos, qvel, effort
data_dict['/observations/qpos'].append(ts.observation['qpos'])
data_dict['/observations/qvel'].append(ts.observation['qvel'])
data_dict['/observations/effort'].append(ts.observation['effort'])
# The action that was actually published
data_dict['/action'].append(action)
data_dict['/base_action'].append(ts.observation['base_vel'])
# Camera data
# data_dict['/base_action_t265'].append(ts.observation['base_vel_t265'])
for cam_name in args.camera_names:
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
if args.use_depth_image:
data_dict[f'/observations/images_depth/{cam_name}'].append(ts.observation['images_depth'][cam_name])
t0 = time.time()
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2*2) as root:
# File attributes:
# 1. whether the data is from simulation
# 2. whether the images are compressed
root.attrs['sim'] = False
root.attrs['compress'] = False
# Create the 'observations' group (observation states)
# and the image group inside it
obs = root.create_group('observations')
image = obs.create_group('images')
for cam_name in args.camera_names:
_ = image.create_dataset(cam_name, (data_size, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
if args.use_depth_image:
image_depth = obs.create_group('images_depth')
for cam_name in args.camera_names:
_ = image_depth.create_dataset(cam_name, (data_size, 480, 640), dtype='uint16',
chunks=(1, 480, 640), )
_ = obs.create_dataset('qpos', (data_size, 14))
_ = obs.create_dataset('qvel', (data_size, 14))
_ = obs.create_dataset('effort', (data_size, 14))
_ = root.create_dataset('action', (data_size, 14))
_ = root.create_dataset('base_action', (data_size, 2))
# data_dict write into h5py.File
for name, array in data_dict.items():
root[name][...] = array
print(f'\033[32m\nSaving: {time.time() - t0:.1f} secs. %s \033[0m\n'%dataset_path)
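# Resulting HDF5 layout (one file per episode):
#   /observations/qpos, /observations/qvel, /observations/effort : (T, 14)
#   /observations/images/<cam_name>        : (T, 480, 640, 3) uint8
#   /observations/images_depth/<cam_name>  : (T, 480, 640) uint16 (only if use_depth_image)
#   /action      : (T, 14)
#   /base_action : (T, 2)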
def is_headless():
"""
Check if the environment is headless (no display available).
Returns:
bool: True if the environment is headless, False otherwise.
"""
try:
import tkinter as tk
root = tk.Tk()
root.withdraw()
root.update()
root.destroy()
return False
except Exception:
return True
def init_keyboard_listener():
"""
Initialize keyboard listener for control events with new key mappings:
- Left arrow: Start data recording
- Right arrow: Save current data
- Down arrow: Discard current data
- Up arrow: Replay current data
- ESC: Early termination
Returns:
tuple: (listener, events) - Keyboard listener and events dictionary
"""
events = {
"exit_early": False,
"record_start": False,
"save_data": False,
"discard_data": False,
"replay_data": False
}
if is_headless():
print(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
return None, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.left:
print("← Left arrow: STARTING data recording...")
events.update({
"record_start": True,
"exit_early": False,
"save_data": False,
"discard_data": False
})
elif key == keyboard.Key.right:
print("→ Right arrow: SAVING current data...")
events.update({
"save_data": True,
"exit_early": False,
"record_start": False
})
elif key == keyboard.Key.down:
print("↓ Down arrow: DISCARDING current data...")
events.update({
"discard_data": True,
"exit_early": False,
"record_start": False
})
elif key == keyboard.Key.up:
print("↑ Up arrow: REPLAYING current data...")
events.update({
"replay_data": True,
"exit_early": False
})
elif key == keyboard.Key.esc:
print("ESC: EARLY TERMINATION requested")
events.update({
"exit_early": True,
"record_start": False
})
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
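# Example usage (a minimal sketch of the polling loop an episode recorder might run):
# listener, events = init_keyboard_listener()
# while not events["exit_early"]:
#     if events["record_start"]:
#         pass  # start buffering observations/actions here
#     if events["save_data"] or events["discard_data"]:
#         break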

View File

@@ -0,0 +1,159 @@
#coding=utf-8
import os
import numpy as np
import cv2
import h5py
import argparse
import matplotlib.pyplot as plt
DT = 0.02
# JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
JOINT_NAMES = ["joint0", "joint1", "joint2", "joint3", "joint4", "joint5"]
STATE_NAMES = JOINT_NAMES + ["gripper"]
BASE_STATE_NAMES = ["linear_vel", "angular_vel"]
def load_hdf5(dataset_dir, dataset_name):
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
if not os.path.isfile(dataset_path):
print(f'Dataset does not exist at \n{dataset_path}\n')
exit()
with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
compressed = root.attrs.get('compress', False)
qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()]
if 'effort' in root.keys():
effort = root['/observations/effort'][()]
else:
effort = None
action = root['/action'][()]
base_action = root['/base_action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
if compressed:
compress_len = root['/compress_len'][()]
if compressed:
for cam_id, cam_name in enumerate(image_dict.keys()):
# un-pad and uncompress
padded_compressed_image_list = image_dict[cam_name]
image_list = []
for frame_id, padded_compressed_image in enumerate(padded_compressed_image_list): # [:1000] to save memory
image_len = int(compress_len[cam_id, frame_id])
compressed_image = padded_compressed_image
image = cv2.imdecode(compressed_image, 1)
image_list.append(image)
image_dict[cam_name] = image_list
return qpos, qvel, effort, action, base_action, image_dict
def main(args):
dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx']
task_name = args['task_name']
dataset_name = f'episode_{episode_idx}'
qpos, qvel, effort, action, base_action, image_dict = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)
print('hdf5 loaded!!')
save_videos(image_dict, action, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
visualize_base(base_action, plot_path=os.path.join(dataset_dir, dataset_name + '_base_action.png'))
def save_videos(video, actions, dt, video_path=None):
cam_names = list(video.keys())
all_cam_videos = []
for cam_name in cam_names:
all_cam_videos.append(video[cam_name])
all_cam_videos = np.concatenate(all_cam_videos, axis=2) # width dimension
n_frames, h, w, _ = all_cam_videos.shape
fps = int(1 / dt)
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames):
image = all_cam_videos[t]
image = image[:, :, [2, 1, 0]] # swap B and R channel
cv2.imshow("images",image)
cv2.waitKey(30)
print("episode_id: ", t, "left: ", np.round(actions[t][:7], 3), "right: ", np.round(actions[t][7:], 3), "\n")
out.write(image)
out.release()
print(f'Saved video to: {video_path}')
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
if label_overwrite:
label1, label2 = label_overwrite
else:
label1, label2 = 'State', 'Command'
qpos = np.array(qpos_list) # ts, dim
command = np.array(command_list)
num_ts, num_dim = qpos.shape
h, w = 2, num_dim
num_figs = num_dim
fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))
# plot joint state
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(qpos[:, dim_idx], label=label1, color='orangered')
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
ax.legend()
# plot arm command
# for dim_idx in range(num_dim):
# ax = axs[dim_idx]
# ax.plot(command[:, dim_idx], label=label2)
# ax.legend()
if ylim:
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.set_ylim(ylim)
plt.tight_layout()
plt.savefig(plot_path)
print(f'Saved qpos plot to: {plot_path}')
plt.close()
def visualize_base(readings, plot_path=None):
readings = np.array(readings) # ts, dim
num_ts, num_dim = readings.shape
num_figs = num_dim
fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))
# plot joint state
all_names = BASE_STATE_NAMES
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(readings[:, dim_idx], label='raw')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(20)/20, mode='same'), label='smoothed_20')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(10)/10, mode='same'), label='smoothed_10')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(5)/5, mode='same'), label='smoothed_5')
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
ax.legend()
plt.tight_layout()
plt.savefig(plot_path)
print(f'Saved effort plot to: {plot_path}')
plt.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
parser.add_argument('--task_name', action='store', type=str, help='Task name.',
default="aloha_mobile_dummy", required=False)
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.',default=0, required=False)
main(vars(parser.parse_args()))
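# Example invocation (script name and paths are illustrative):
# python visualize_episodes.py --dataset_dir ./data --task_name aloha_mobile_dummy --episode_idx 0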