Compare commits
8 Commits
91c2b7b0cb
...
25fb9c0d33
| Author | SHA1 | Date | |
|---|---|---|---|
| 25fb9c0d33 | |||
| 722de584d2 | |||
| a4fe5ee09a | |||
| 3df284ddd1 | |||
| 88885a6a25 | |||
| aa3920dd28 | |||
| e034881507 | |||
| d843a990a3 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,2 +1,5 @@
|
|||||||
cobot_magic/
|
cobot_magic/
|
||||||
librealsense/
|
librealsense/
|
||||||
|
data*/
|
||||||
|
outputs/
|
||||||
|
lerobot_datasets/
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data
|
|
||||||
|
|
||||||
python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1
|
|
||||||
Binary file not shown.
BIN
collect_data/__pycache__/aloha_mobile.cpython-310.pyc
Normal file
BIN
collect_data/__pycache__/aloha_mobile.cpython-310.pyc
Normal file
Binary file not shown.
BIN
collect_data/__pycache__/collect_data.cpython-310.pyc
Normal file
BIN
collect_data/__pycache__/collect_data.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
collect_data/__pycache__/utils.cpython-310.pyc
Normal file
BIN
collect_data/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
23
collect_data/aloha.yaml
Normal file
23
collect_data/aloha.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
camera_names:
|
||||||
|
- cam_high
|
||||||
|
- cam_left_wrist
|
||||||
|
- cam_right_wrist
|
||||||
|
dataset_dir: /home/ubuntu/LYT/lerobot_aloha/datasets/3camera
|
||||||
|
episode_idx: 0
|
||||||
|
frame_rate: 30
|
||||||
|
img_front_depth_topic: /camera_f/depth/image_raw
|
||||||
|
img_front_topic: /camera_f/color/image_raw
|
||||||
|
img_left_depth_topic: /camera_l/depth/image_raw
|
||||||
|
img_left_topic: /camera_l/color/image_raw
|
||||||
|
img_right_depth_topic: /camera_r/depth/image_raw
|
||||||
|
img_right_topic: /camera_r/color/image_raw
|
||||||
|
master_arm_left_topic: /master/joint_left
|
||||||
|
master_arm_right_topic: /master/joint_right
|
||||||
|
max_timesteps: 500
|
||||||
|
num_episodes: 50
|
||||||
|
puppet_arm_left_topic: /puppet/joint_left
|
||||||
|
puppet_arm_right_topic: /puppet/joint_right
|
||||||
|
robot_base_topic: /odom
|
||||||
|
task_name: aloha_mobile_dummy
|
||||||
|
use_depth_image: false
|
||||||
|
use_robot_base: false
|
||||||
305
collect_data/aloha_mobile.py
Normal file
305
collect_data/aloha_mobile.py
Normal file
@@ -0,0 +1,305 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import dm_env
|
||||||
|
|
||||||
|
import collections
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
import rospy
|
||||||
|
from sensor_msgs.msg import JointState
|
||||||
|
from sensor_msgs.msg import Image
|
||||||
|
from nav_msgs.msg import Odometry
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
from utils import display_camera_grid, save_data
|
||||||
|
|
||||||
|
|
||||||
|
class AlohaRobotRos:
    """Buffer ROS sensor messages and assemble time-aligned recording frames.

    One subscriber per topic appends incoming messages to a bounded queue.
    ``get_frame`` pops one message per stream at a common timestamp, and
    ``process`` turns those frames into ``dm_env`` time steps plus the
    master-arm joint positions that are stored as actions.

    Args:
        args: Configuration object. Attributes read by this class:
            ``img_{left,right,front}_topic``, ``img_*_depth_topic`` (only
            when ``use_depth_image``), ``master_arm_{left,right}_topic``,
            ``puppet_arm_{left,right}_topic``, ``robot_base_topic`` (only
            when ``use_robot_base``), ``use_depth_image``,
            ``use_robot_base``, ``camera_names``, ``frame_rate``,
            ``max_timesteps`` and ``episode_idx``.
    """

    # Maximum number of messages kept per topic; the oldest is dropped first.
    MAX_QUEUE_LEN = 2000
    # Rows of zero padding added above and below every depth image.
    DEPTH_PAD_ROWS = 40

    def __init__(self, args):
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.master_arm_right_deque = None
        self.master_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.args = args
        self.init()
        self.init_ros()

    @classmethod
    def _new_queue(cls):
        """Return an empty message queue capped at ``MAX_QUEUE_LEN`` entries."""
        return deque(maxlen=cls.MAX_QUEUE_LEN)

    def init(self):
        """Create the image bridge and one empty bounded queue per stream."""
        self.bridge = CvBridge()
        self.img_left_deque = self._new_queue()
        self.img_right_deque = self._new_queue()
        self.img_front_deque = self._new_queue()
        self.img_left_depth_deque = self._new_queue()
        self.img_right_depth_deque = self._new_queue()
        self.img_front_depth_deque = self._new_queue()
        self.master_arm_left_deque = self._new_queue()
        self.master_arm_right_deque = self._new_queue()
        self.puppet_arm_left_deque = self._new_queue()
        self.puppet_arm_right_deque = self._new_queue()
        self.robot_base_deque = self._new_queue()

    @staticmethod
    def _latest_stamp(queue):
        """Return the header timestamp (seconds) of the newest queued message."""
        return queue[-1].header.stamp.to_sec()

    @staticmethod
    def _pop_synced(queue, frame_time):
        """Drop messages older than ``frame_time`` and pop the next one.

        The caller must already have checked that the newest message in
        ``queue`` is not older than ``frame_time``; that is what guarantees
        the loop stops before the queue runs empty.
        """
        while queue[0].header.stamp.to_sec() < frame_time:
            queue.popleft()
        return queue.popleft()

    def _pop_image(self, queue, frame_time):
        """Pop the synced image message and convert it with the bridge."""
        return self.bridge.imgmsg_to_cv2(self._pop_synced(queue, frame_time), 'passthrough')

    def _pop_depth_image(self, queue, frame_time):
        """Pop the synced depth image and pad it with zero rows top and bottom."""
        depth = self._pop_image(queue, frame_time)
        pad = self.DEPTH_PAD_ROWS
        return cv2.copyMakeBorder(depth, pad, pad, 0, 0, cv2.BORDER_CONSTANT, value=0)

    def get_frame(self):
        """Pop one time-aligned sample from every active stream.

        The frame time is the oldest of the newest image timestamps. Every
        active stream must hold a message at least that new; otherwise
        nothing is consumed and ``False`` is returned.

        Returns:
            ``False`` when the streams cannot be aligned yet, otherwise the
            tuple ``(img_front, img_left, img_right, img_front_depth,
            img_left_depth, img_right_depth, puppet_arm_left,
            puppet_arm_right, master_arm_left, master_arm_right,
            robot_base)``. Depth entries are ``None`` unless
            ``args.use_depth_image`` is set; ``robot_base`` is ``None``
            unless ``args.use_robot_base`` is set.
        """
        use_depth = self.args.use_depth_image
        use_base = self.args.use_robot_base

        print(len(self.img_left_deque), len(self.img_right_deque), len(self.img_front_deque),
              len(self.img_left_depth_deque), len(self.img_right_depth_deque), len(self.img_front_depth_deque))

        image_queues = [self.img_left_deque, self.img_right_deque, self.img_front_deque]
        if use_depth:
            image_queues += [self.img_left_depth_deque, self.img_right_depth_deque,
                             self.img_front_depth_deque]
        if any(len(queue) == 0 for queue in image_queues):
            return False

        frame_time = min(self._latest_stamp(queue) for queue in image_queues)

        required = image_queues + [self.master_arm_left_deque, self.master_arm_right_deque,
                                   self.puppet_arm_left_deque, self.puppet_arm_right_deque]
        if use_base:
            required.append(self.robot_base_deque)
        for queue in required:
            if len(queue) == 0 or self._latest_stamp(queue) < frame_time:
                return False

        img_left = self._pop_image(self.img_left_deque, frame_time)
        img_right = self._pop_image(self.img_right_deque, frame_time)
        img_front = self._pop_image(self.img_front_deque, frame_time)

        master_arm_left = self._pop_synced(self.master_arm_left_deque, frame_time)
        master_arm_right = self._pop_synced(self.master_arm_right_deque, frame_time)
        puppet_arm_left = self._pop_synced(self.puppet_arm_left_deque, frame_time)
        puppet_arm_right = self._pop_synced(self.puppet_arm_right_deque, frame_time)

        img_left_depth = img_right_depth = img_front_depth = None
        if use_depth:
            img_left_depth = self._pop_depth_image(self.img_left_depth_deque, frame_time)
            img_right_depth = self._pop_depth_image(self.img_right_depth_deque, frame_time)
            img_front_depth = self._pop_depth_image(self.img_front_depth_deque, frame_time)

        robot_base = None
        if use_base:
            robot_base = self._pop_synced(self.robot_base_deque, frame_time)

        return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
                puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base)

    # Subscriber callbacks: each appends to its bounded queue, which discards
    # the oldest entry automatically once MAX_QUEUE_LEN is reached.

    def img_left_callback(self, msg):
        self.img_left_deque.append(msg)

    def img_right_callback(self, msg):
        self.img_right_deque.append(msg)

    def img_front_callback(self, msg):
        self.img_front_deque.append(msg)

    def img_left_depth_callback(self, msg):
        self.img_left_depth_deque.append(msg)

    def img_right_depth_callback(self, msg):
        self.img_right_depth_deque.append(msg)

    def img_front_depth_callback(self, msg):
        self.img_front_depth_deque.append(msg)

    def master_arm_left_callback(self, msg):
        self.master_arm_left_deque.append(msg)

    def master_arm_right_callback(self, msg):
        self.master_arm_right_deque.append(msg)

    def puppet_arm_left_callback(self, msg):
        self.puppet_arm_left_deque.append(msg)

    def puppet_arm_right_callback(self, msg):
        self.puppet_arm_right_deque.append(msg)

    def robot_base_callback(self, msg):
        self.robot_base_deque.append(msg)

    def init_ros(self):
        """Start the ``record_episodes`` node and subscribe to every active topic."""
        args = self.args
        rospy.init_node('record_episodes', anonymous=True)

        subscriptions = [
            (args.img_left_topic, Image, self.img_left_callback),
            (args.img_right_topic, Image, self.img_right_callback),
            (args.img_front_topic, Image, self.img_front_callback),
        ]
        if args.use_depth_image:
            subscriptions += [
                (args.img_left_depth_topic, Image, self.img_left_depth_callback),
                (args.img_right_depth_topic, Image, self.img_right_depth_callback),
                (args.img_front_depth_topic, Image, self.img_front_depth_callback),
            ]
        subscriptions += [
            (args.master_arm_left_topic, JointState, self.master_arm_left_callback),
            (args.master_arm_right_topic, JointState, self.master_arm_right_callback),
            (args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback),
            (args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback),
        ]
        if args.use_robot_base:
            # get_frame() waits for odometry whenever use_robot_base is set;
            # without this subscription it could never return a frame.
            subscriptions.append((args.robot_base_topic, Odometry, self.robot_base_callback))

        for topic, msg_type, callback in subscriptions:
            rospy.Subscriber(topic, msg_type, callback, queue_size=1000, tcp_nodelay=True)

    def _build_observation(self, frame):
        """Turn one ``get_frame`` tuple into the ordered observation dict."""
        (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
         puppet_arm_left, puppet_arm_right, _master_arm_left, _master_arm_right, robot_base) = frame
        names = self.args.camera_names

        obs = collections.OrderedDict()
        # camera_names is read positionally: front, left, right.
        obs['images'] = {names[0]: img_front, names[1]: img_left, names[2]: img_right}
        if self.args.use_depth_image:
            obs['images_depth'] = {names[0]: img_front_depth,
                                   names[1]: img_left_depth,
                                   names[2]: img_right_depth}
        # Puppet-arm joint state, left arm first, then right.
        obs['qpos'] = np.concatenate((np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
        obs['qvel'] = np.concatenate((np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
        obs['effort'] = np.concatenate((np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
        if self.args.use_robot_base:
            obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
        else:
            obs['base_vel'] = [0.0, 0.0]
        return obs

    def process(self):
        """Record one episode at ``args.frame_rate``.

        Collects ``args.max_timesteps + 1`` synchronized frames. The first
        frame becomes a ``FIRST`` time step with no action; each later frame
        becomes a ``MID`` time step paired with the master-arm positions.

        Returns:
            ``(timesteps, actions)``. A full episode has
            ``max_timesteps + 1`` time steps and ``max_timesteps`` actions.

        Raises:
            SystemExit: with status -1 if ROS shuts down right after a frame
                was recorded.
        """
        timesteps = []
        actions = []
        count = 0

        rate = rospy.Rate(self.args.frame_rate)
        print_flag = True  # report a run of sync failures only once

        while (count < self.args.max_timesteps + 1) and not rospy.is_shutdown():
            # Collect one synchronized frame.
            result = self.get_frame()
            if not result:
                if print_flag:
                    print("syn fail\n")
                    print_flag = False
                rate.sleep()
                continue
            print_flag = True
            count += 1

            obs = self._build_observation(result)
            display_camera_grid(obs['images'])

            # The first frame only yields a FIRST time step; no action is
            # stored for it and the loop does not sleep before the next frame.
            if count == 1:
                timesteps.append(dm_env.TimeStep(
                    step_type=dm_env.StepType.FIRST,
                    reward=None,
                    discount=None,
                    observation=obs))
                continue

            ts = dm_env.TimeStep(
                step_type=dm_env.StepType.MID,
                reward=None,
                discount=None,
                observation=obs)

            # The master-arm joint positions are stored as the action.
            master_arm_left, master_arm_right = result[8], result[9]
            action = np.concatenate((np.array(master_arm_left.position), np.array(master_arm_right.position)), axis=0)
            actions.append(action)
            timesteps.append(ts)
            print(f"\n{self.args.episode_idx} | Frame data: {count}\r", end="")
            if rospy.is_shutdown():
                raise SystemExit(-1)
            rate.sleep()

        print("len(timesteps): ", len(timesteps))
        print("len(actions)  : ", len(actions))
        return timesteps, actions
|
||||||
114
collect_data/collect_data.py
Executable file
114
collect_data/collect_data.py
Executable file
@@ -0,0 +1,114 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from aloha_mobile import AlohaRobotRos
|
||||||
|
from utils import save_data, init_keyboard_listener, load_config, log_say
|
||||||
|
|
||||||
|
|
||||||
|
def main(config_path):
    """Record one or more episodes as described by a config file.

    With ``num_episodes == 1`` a single episode is recorded and saved right
    away. Otherwise recording is driven by the keyboard flags returned by
    ``init_keyboard_listener``: start, save, discard and exit.

    Args:
        config_path: Path handed to ``load_config``.

    Returns:
        ``0`` when the single episode was saved, or when exit was requested
        while a save/discard decision was pending; ``-1`` when the single
        episode holds fewer than ``max_timesteps`` actions; ``None`` in all
        other cases.
    """
    # NOTE(review): nesting was reconstructed from a diff view that lost all
    # indentation; the placement of the `break` after a save attempt and of
    # the outer-loop sleep should be confirmed against the real file.
    cfg = load_config(config_path)

    robot = AlohaRobotRos(cfg)
    task_dir = os.path.join(cfg.dataset_dir, cfg.task_name)
    # Make sure the task directory exists before anything is written.
    os.makedirs(task_dir, exist_ok=True)

    # ---- single-episode mode ------------------------------------------
    if cfg.num_episodes == 1:
        print(f"Recording single episode {cfg.episode_idx}...")
        timesteps, actions = robot.process()

        if len(actions) >= cfg.max_timesteps:
            episode_path = os.path.join(task_dir, f"episode_{cfg.episode_idx}")
            save_data(cfg, timesteps, actions, episode_path)
            print(f"\033[32mEpisode {cfg.episode_idx} saved successfully at {episode_path}\033[0m")
            return 0

        print(f"\033[31m\nSave failure: Recorded only {len(actions)}/{cfg.max_timesteps} timesteps.\033[0m\n")
        return -1

    # ---- multi-episode mode -------------------------------------------
    # NOTE(review): any leading whitespace inside this literal could not be
    # recovered from the diff view.
    print("""
\033[1;36mKeyboard Controls:\033[0m
← \033[1mLeft Arrow\033[0m: Start Recording
→ \033[1mRight Arrow\033[0m: Save Current Data
↓ \033[1mDown Arrow\033[0m: Discard Current Data
↑ \033[1mUp Arrow\033[0m: Replay Data (if implemented)
\033[1mESC\033[0m: Exit Program
""")
    log_say("欢迎您为 具身智能科学家项目采集数据,您辛苦了。我已经将一切准备就绪,请您按方向左键开始录制数据。", play_sounds=True)

    listener, events = init_keyboard_listener()
    episode_idx = cfg.episode_idx
    saved_count = 0

    try:
        while saved_count < cfg.num_episodes:
            if events["exit_early"]:
                print("\033[33mOperation terminated by user\033[0m")
                log_say("操作被你停止了,如果这是个误操作,请重新开始。", play_sounds=True)
                break

            if not events["record_start"]:
                time.sleep(0.1)  # idle poll; keeps CPU usage low
                continue

            # Clear every flag so the new recording starts from a clean state.
            for flag in ("record_start", "save_data", "discard_data"):
                events[flag] = False

            log_say(f"开始录制第{episode_idx}条轨迹,请开始操作机械臂。", play_sounds=True)
            print(f"\n\033[1;32mRecording episode {episode_idx}...\033[0m")
            timesteps, actions = robot.process()
            print(f"\033[1;33mRecorded {len(actions)} timesteps. (→ to save, ↓ to discard)\033[0m")
            log_say(f"第{episode_idx}条轨迹的录制已经到达最大时间步并录制结束,请选择是否保留该条轨迹。按方向右键保留,方向下键丢弃。", play_sounds=True)

            # Block until the user decides what happens to this recording.
            while True:
                if events["save_data"]:
                    events["save_data"] = False

                    if len(actions) >= cfg.max_timesteps:
                        episode_path = os.path.join(task_dir, f"episode_{episode_idx}")
                        log_say(f"你选择了保留该轨迹作为第{episode_idx}条轨迹数据。接下来请你按方向左键开始录制下一条轨迹。", play_sounds=True)
                        save_data(cfg, timesteps, actions, episode_path)
                        print(f"\033[32mEpisode {episode_idx} saved successfully at {episode_path}\033[0m")
                        episode_idx += 1
                        saved_count += 1
                        print(f"\033[1mProgress: {saved_count}/{cfg.num_episodes} episodes collected. (← to start new episode)\033[0m")
                    else:
                        print(f"\033[31mSave failure: Recorded only {len(actions)}/{cfg.max_timesteps} timesteps.\033[0m")
                        log_say(f"由于当前轨迹的实际时间步数小于最大时间步数,因此无法保存该条轨迹。该条轨迹将被丢弃,请重新录制。", play_sounds=True)
                    break

                if events["discard_data"]:
                    events["discard_data"] = False
                    log_say(f"你选择了丢弃该轨迹作为第{episode_idx}条轨迹数据。接下来请你按方向左键开始录制下一条轨迹。", play_sounds=True)
                    print("\033[33mData discarded. Press ← to start a new recording.\033[0m")
                    break

                if events["exit_early"]:
                    print("\033[33mOperation terminated by user\033[0m")
                    log_say("操作被你停止了,如果这是个误操作,请重新开始。", play_sounds=True)
                    return 0

                time.sleep(0.1)  # decision poll; keeps CPU usage low

            time.sleep(0.1)  # same pause the idle path takes

        if saved_count == cfg.num_episodes:
            log_say("恭喜你,本次数据已经全部录制完成。您辛苦了~", play_sounds=True)
            print(f"\n\033[1;32mData collection complete! All {cfg.num_episodes} episodes collected.\033[0m")

    finally:
        # Always stop the keyboard listener, including on early return.
        if listener:
            listener.stop()
            print("Keyboard listener stopped")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Script entry point: run main() on the hard-coded config and turn its
    # result into the process exit status.
    # SystemExit is not an Exception subclass, so the exits raised inside the
    # try block are not swallowed by the generic handler below.
    try:
        exit_code = main("/home/ubuntu/LYT/lerobot_aloha/collect_data/aloha.yaml")
        raise SystemExit(exit_code if exit_code is not None else 0)
    except KeyboardInterrupt:
        print("\n\033[33mProgram interrupted by user\033[0m")
        raise SystemExit(0)
    except Exception as e:
        print(f"\n\033[31mError: {e}\033[0m")
        # Report failure to the caller instead of falling through to status 0.
        raise SystemExit(1)
|
||||||
403
collect_data/collect_data_gui.py
Normal file
403
collect_data/collect_data_gui.py
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
from collect_data import main
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
|
||||||
|
QLabel, QLineEdit, QPushButton, QSpinBox, QCheckBox,
|
||||||
|
QGroupBox, QTabWidget, QMessageBox, QFileDialog)
|
||||||
|
from PyQt5.QtCore import Qt
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AlohaDataCollectionGUI(QMainWindow):
|
||||||
|
def __init__(self):
    """Build the main window, wire its signals and load the default config."""
    super().__init__()
    self.setWindowTitle("MindRobot-V1 Data Collection")
    self.setGeometry(100, 100, 800, 600)

    # A single central widget carries the vertical root layout.
    container = QWidget()
    self.central_widget = container
    self.setCentralWidget(container)
    self.main_layout = QVBoxLayout(container)

    # Default configuration file shown in the UI and loaded at start-up.
    default_config = "/home/ubuntu/LYT/lerobot_aloha/collect_data/aloha.yaml"
    self.config_path = os.path.expanduser(default_config)

    # Widgets must exist before signals are connected and values loaded.
    self.create_ui()
    self.setup_connections()
    self.load_default_config()
|
||||||
|
|
||||||
|
def create_ui(self):
    """Create the three settings tabs and the row of control buttons."""
    self.tabs = QTabWidget()
    self.main_layout.addWidget(self.tabs)

    # Each tab is registered first and then populated by its builder.
    self.general_tab = QWidget()
    self.tabs.addTab(self.general_tab, "General Settings")
    self.create_general_tab()

    self.camera_tab = QWidget()
    self.tabs.addTab(self.camera_tab, "Camera Settings")
    self.create_camera_tab()

    self.arm_tab = QWidget()
    self.tabs.addTab(self.arm_tab, "Arm Settings")
    self.create_arm_tab()

    # Control buttons, laid out left to right in creation order.
    self.load_config_button = QPushButton("Load Config")
    self.save_config_button = QPushButton("Save Config")
    self.start_button = QPushButton("Start Recording")
    self.stop_button = QPushButton("Stop Recording")
    self.exit_button = QPushButton("Exit")

    button_row = QHBoxLayout()
    for button in (self.load_config_button, self.save_config_button,
                   self.start_button, self.stop_button, self.exit_button):
        button_row.addWidget(button)

    self.control_group = QGroupBox("Control")
    self.control_group.setLayout(button_row)
    self.main_layout.addWidget(self.control_group)
|
||||||
|
|
||||||
|
def create_general_tab(self):
    """Populate the General tab: config file, dataset dir, task and options."""
    tab_layout = QVBoxLayout(self.general_tab)

    # --- configuration file row ---
    self.config_path_edit = QLineEdit(self.config_path)
    self.browse_config_button = QPushButton("Browse...")
    config_row = QHBoxLayout()
    for widget in (QLabel("Config File:"), self.config_path_edit, self.browse_config_button):
        config_row.addWidget(widget)
    config_box = QGroupBox("Configuration File")
    config_box.setLayout(config_row)
    tab_layout.addWidget(config_box)

    # --- dataset directory row ---
    self.dataset_dir_edit = QLineEdit()
    self.browse_dir_button = QPushButton("Browse...")
    dir_row = QHBoxLayout()
    for widget in (QLabel("Dataset Directory:"), self.dataset_dir_edit, self.browse_dir_button):
        dir_row.addWidget(widget)
    dir_box = QGroupBox("Dataset Directory")
    dir_box.setLayout(dir_row)
    tab_layout.addWidget(dir_box)

    # --- task settings: one caption above each input ---
    self.task_name_edit = QLineEdit()
    self.episode_idx_spin = QSpinBox()
    self.episode_idx_spin.setRange(0, 9999)
    self.max_timesteps_spin = QSpinBox()
    self.max_timesteps_spin.setRange(1, 10000)
    self.num_episodes_spin = QSpinBox()
    self.num_episodes_spin.setRange(1, 1000)
    self.frame_rate_spin = QSpinBox()
    self.frame_rate_spin.setRange(1, 60)

    task_column = QVBoxLayout()
    task_fields = (
        ("Task Name:", self.task_name_edit),
        ("Episode Index:", self.episode_idx_spin),
        ("Max Timesteps:", self.max_timesteps_spin),
        ("Number of Episodes:", self.num_episodes_spin),
        ("Frame Rate:", self.frame_rate_spin),
    )
    for caption, field in task_fields:
        task_column.addWidget(QLabel(caption))
        task_column.addWidget(field)
    task_box = QGroupBox("Task Settings")
    task_box.setLayout(task_column)
    tab_layout.addWidget(task_box)

    # --- boolean options ---
    self.use_robot_base_check = QCheckBox("Use Robot Base")
    self.use_depth_image_check = QCheckBox("Use Depth Image")
    options_column = QVBoxLayout()
    for checkbox in (self.use_robot_base_check, self.use_depth_image_check):
        options_column.addWidget(checkbox)
    options_box = QGroupBox("Options")
    options_box.setLayout(options_column)
    tab_layout.addWidget(options_box)

    # Trailing stretch keeps the groups packed at the top of the tab.
    tab_layout.addStretch()
|
||||||
|
|
||||||
|
def create_camera_tab(self):
    """Populate the Camera tab with color and depth topic inputs."""
    tab_layout = QVBoxLayout(self.camera_tab)

    # --- color image topics ---
    self.img_front_topic_edit = QLineEdit()
    self.img_left_topic_edit = QLineEdit()
    self.img_right_topic_edit = QLineEdit()

    color_column = QVBoxLayout()
    color_fields = (
        ("Front Camera Topic:", self.img_front_topic_edit),
        ("Left Camera Topic:", self.img_left_topic_edit),
        ("Right Camera Topic:", self.img_right_topic_edit),
    )
    for caption, field in color_fields:
        color_column.addWidget(QLabel(caption))
        color_column.addWidget(field)
    color_box = QGroupBox("Color Image Topics")
    color_box.setLayout(color_column)
    tab_layout.addWidget(color_box)

    # --- depth image topics ---
    self.img_front_depth_topic_edit = QLineEdit()
    self.img_left_depth_topic_edit = QLineEdit()
    self.img_right_depth_topic_edit = QLineEdit()

    depth_column = QVBoxLayout()
    depth_fields = (
        ("Front Depth Topic:", self.img_front_depth_topic_edit),
        ("Left Depth Topic:", self.img_left_depth_topic_edit),
        ("Right Depth Topic:", self.img_right_depth_topic_edit),
    )
    for caption, field in depth_fields:
        depth_column.addWidget(QLabel(caption))
        depth_column.addWidget(field)
    depth_box = QGroupBox("Depth Image Topics")
    depth_box.setLayout(depth_column)
    tab_layout.addWidget(depth_box)

    # Trailing stretch keeps the groups packed at the top of the tab.
    tab_layout.addStretch()
|
||||||
|
|
||||||
|
def create_arm_tab(self):
    """Populate the arm tab with master-arm, puppet-arm and robot-base topic fields.

    Each section is a titled group box holding one caption label and one
    line edit per topic. The line edits are stored on ``self`` under the
    attribute names listed below.
    """
    tab_layout = QVBoxLayout(self.arm_tab)

    # (group title, ((caption, attribute that receives the QLineEdit), ...))
    sections = (
        (
            "Master Arm Topics",
            (
                ("Master Left Arm Topic:", "master_arm_left_topic_edit"),
                ("Master Right Arm Topic:", "master_arm_right_topic_edit"),
            ),
        ),
        (
            "Puppet Arm Topics",
            (
                ("Puppet Left Arm Topic:", "puppet_arm_left_topic_edit"),
                ("Puppet Right Arm Topic:", "puppet_arm_right_topic_edit"),
            ),
        ),
        (
            "Robot Base Topic",
            (
                ("Robot Base Topic:", "robot_base_topic_edit"),
            ),
        ),
    )

    for title, fields in sections:
        group_box = QGroupBox(title)
        group_layout = QVBoxLayout()

        # Create all editors of the section first, then lay them out in order.
        for _, attr_name in fields:
            setattr(self, attr_name, QLineEdit())
        for caption, attr_name in fields:
            group_layout.addWidget(QLabel(caption))
            group_layout.addWidget(getattr(self, attr_name))

        group_box.setLayout(group_layout)
        tab_layout.addWidget(group_box)

    # Push the groups to the top of the tab.
    tab_layout.addStretch()
|
||||||
|
|
||||||
|
def setup_connections(self):
    """Connect every button's ``clicked`` signal to the method that handles it."""
    wiring = (
        (self.load_config_button, self.load_config),
        (self.save_config_button, self.save_config),
        (self.browse_config_button, self.browse_config_file),
        (self.browse_dir_button, self.browse_dataset_dir),
        (self.start_button, self.start_recording),
        (self.stop_button, self.stop_recording),
        (self.exit_button, self.close),
    )
    for button, handler in wiring:
        button.clicked.connect(handler)
|
||||||
|
|
||||||
|
def load_default_config(self):
    """Fill the form with the built-in default settings."""
    defaults = dict(
        dataset_dir='/home/ubuntu/LYT/lerobot_aloha/datasets/3camera',
        task_name='aloha_mobile_dummy',
        episode_idx=0,
        max_timesteps=500,
        camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        num_episodes=50,
        img_front_topic='/camera_f/color/image_raw',
        img_left_topic='/camera_l/color/image_raw',
        img_right_topic='/camera_r/color/image_raw',
        img_front_depth_topic='/camera_f/depth/image_raw',
        img_left_depth_topic='/camera_l/depth/image_raw',
        img_right_depth_topic='/camera_r/depth/image_raw',
        master_arm_left_topic='/master/joint_left',
        master_arm_right_topic='/master/joint_right',
        puppet_arm_left_topic='/puppet/joint_left',
        puppet_arm_right_topic='/puppet/joint_right',
        robot_base_topic='/odom',
        use_robot_base=False,
        use_depth_image=False,
        frame_rate=30,
    )

    # The widgets are the single source of truth, so just push the values in.
    self.update_ui_from_config(defaults)
|
||||||
|
|
||||||
|
def update_ui_from_config(self, config):
    """Update UI elements from a config dictionary.

    Keys missing from ``config`` fall back to the default listed in the
    binding table below. ``camera_names`` has no widget and is ignored.
    """
    # (widget attribute, setter method, config key, fallback), in update order
    bindings = (
        ('dataset_dir_edit', 'setText', 'dataset_dir', ''),
        ('task_name_edit', 'setText', 'task_name', ''),
        ('episode_idx_spin', 'setValue', 'episode_idx', 0),
        ('max_timesteps_spin', 'setValue', 'max_timesteps', 500),
        ('num_episodes_spin', 'setValue', 'num_episodes', 1),
        ('frame_rate_spin', 'setValue', 'frame_rate', 30),
        ('img_front_topic_edit', 'setText', 'img_front_topic', ''),
        ('img_left_topic_edit', 'setText', 'img_left_topic', ''),
        ('img_right_topic_edit', 'setText', 'img_right_topic', ''),
        ('img_front_depth_topic_edit', 'setText', 'img_front_depth_topic', ''),
        ('img_left_depth_topic_edit', 'setText', 'img_left_depth_topic', ''),
        ('img_right_depth_topic_edit', 'setText', 'img_right_depth_topic', ''),
        ('master_arm_left_topic_edit', 'setText', 'master_arm_left_topic', ''),
        ('master_arm_right_topic_edit', 'setText', 'master_arm_right_topic', ''),
        ('puppet_arm_left_topic_edit', 'setText', 'puppet_arm_left_topic', ''),
        ('puppet_arm_right_topic_edit', 'setText', 'puppet_arm_right_topic', ''),
        ('robot_base_topic_edit', 'setText', 'robot_base_topic', ''),
        ('use_robot_base_check', 'setChecked', 'use_robot_base', False),
        ('use_depth_image_check', 'setChecked', 'use_depth_image', False),
    )
    for widget_name, setter_name, key, fallback in bindings:
        widget = getattr(self, widget_name)
        getattr(widget, setter_name)(config.get(key, fallback))
|
||||||
|
|
||||||
|
def get_config_from_ui(self):
    """Get current UI values as a config dictionary.

    ``camera_names`` is not editable in the form and is always the same
    three-camera list.
    """
    settings = dict(
        dataset_dir=self.dataset_dir_edit.text(),
        task_name=self.task_name_edit.text(),
        episode_idx=self.episode_idx_spin.value(),
        max_timesteps=self.max_timesteps_spin.value(),
        camera_names=['cam_high', 'cam_left_wrist', 'cam_right_wrist'],
        num_episodes=self.num_episodes_spin.value(),
    )

    # Every topic key is backed by a line edit named "<key>_edit".
    topic_keys = (
        'img_front_topic',
        'img_left_topic',
        'img_right_topic',
        'img_front_depth_topic',
        'img_left_depth_topic',
        'img_right_depth_topic',
        'master_arm_left_topic',
        'master_arm_right_topic',
        'puppet_arm_left_topic',
        'puppet_arm_right_topic',
        'robot_base_topic',
    )
    for key in topic_keys:
        settings[key] = getattr(self, key + '_edit').text()

    settings['use_robot_base'] = self.use_robot_base_check.isChecked()
    settings['use_depth_image'] = self.use_depth_image_check.isChecked()
    settings['frame_rate'] = self.frame_rate_spin.value()
    return settings
|
||||||
|
|
||||||
|
def browse_config_file(self):
    """Ask the user for a YAML config file and load it right away."""
    chosen_path, _selected_filter = QFileDialog.getOpenFileName(
        self, "Select Config File", "", "YAML Files (*.yaml *.yml)"
    )
    # An empty path means the dialog was dismissed without a selection.
    if not chosen_path:
        return

    self.config_path_edit.setText(chosen_path)
    self.load_config()
|
||||||
|
|
||||||
|
def browse_dataset_dir(self):
    """Ask the user for a directory and put it in the dataset-directory field."""
    chosen_dir = QFileDialog.getExistingDirectory(self, "Select Dataset Directory")
    # An empty result means the dialog was dismissed; keep the current value.
    if not chosen_dir:
        return

    self.dataset_dir_edit.setText(chosen_dir)
|
||||||
|
|
||||||
|
def load_config(self):
    """Load the YAML file named in the config-path field and apply it to the UI.

    Shows a warning dialog when the path does not exist and an error dialog
    when the file cannot be read, parsed, or does not contain a mapping of
    settings. On any failure the form is left unchanged.
    """
    config_path = self.config_path_edit.text()
    if not os.path.exists(config_path):
        QMessageBox.warning(self, "Warning", f"Config file not found: {config_path}")
        return

    try:
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        # safe_load returns None for an empty document and may return a list or
        # scalar for other YAML; only a mapping can be applied to the form.
        if config is None:
            raise ValueError("config file is empty")
        if not isinstance(config, dict):
            raise ValueError(
                f"expected a mapping of settings, got {type(config).__name__}"
            )

        self.update_ui_from_config(config)
        self.statusBar().showMessage(f"Config loaded from {config_path}", 3000)
    except Exception as e:
        # UI boundary: report every failure to the user instead of crashing.
        QMessageBox.critical(self, "Error", f"Failed to load config: {str(e)}")
|
||||||
|
|
||||||
|
def save_config(self):
    """Write the current form values as YAML to the path in the config-path field.

    Warns when no path is given; any failure while collecting or writing the
    settings is reported in an error dialog.
    """
    target_path = self.config_path_edit.text()
    if target_path:
        try:
            settings = self.get_config_from_ui()
            with open(target_path, 'w') as stream:
                yaml.dump(settings, stream, default_flow_style=False)
            self.statusBar().showMessage(f"Config saved to {target_path}", 3000)
        except Exception as err:
            QMessageBox.critical(self, "Error", f"Failed to save config: {str(err)}")
    else:
        QMessageBox.warning(self, "Warning", "Please specify a config file path")
|
||||||
|
|
||||||
|
def start_recording(self):
    """Validate the form, write it to a temporary YAML file and run ``main`` on it.

    The dataset directory and task name must be non-empty; otherwise a
    warning is shown and nothing is started. ``main`` is called in the
    current thread with the temporary config path, and its return value is
    treated as an exit code (0 means success). Any exception is reported in
    an error dialog.
    """
    import tempfile  # local import: only this handler needs it

    temp_config_path = None
    try:
        config = self.get_config_from_ui()

        # Validate inputs
        if not config['dataset_dir']:
            QMessageBox.warning(self, "Warning", "Dataset directory cannot be empty!")
            return

        if not config['task_name']:
            QMessageBox.warning(self, "Warning", "Task name cannot be empty!")
            return

        # Use a uniquely named, owner-only file rather than a fixed /tmp path:
        # a predictable name can collide between users/instances and is open
        # to symlink tricks on shared machines.
        with tempfile.NamedTemporaryFile(
            'w', prefix="aloha_temp_config_", suffix=".yaml", delete=False
        ) as f:
            temp_config_path = f.name
            yaml.dump(config, f, default_flow_style=False)

        self.statusBar().showMessage("Recording started...")

        # Start recording with the temporary config file
        exit_code = main(temp_config_path)

        if exit_code == 0:
            self.statusBar().showMessage("Recording completed successfully!", 5000)
        else:
            self.statusBar().showMessage("Recording completed with errors", 5000)

    except Exception as e:
        QMessageBox.critical(self, "Error", f"An error occurred: {str(e)}")
        self.statusBar().showMessage("Recording failed", 5000)
    finally:
        # NOTE(review): assumes main() is done with the file once it returns.
        if temp_config_path is not None:
            try:
                os.remove(temp_config_path)
            except OSError:
                pass  # best effort: a leftover temp file is harmless
|
||||||
|
|
||||||
|
def stop_recording(self):
    """Placeholder stop handler: only informs the user, nothing is actually stopped."""
    # In a real application, this would signal the recording thread to stop
    status_bar = self.statusBar()
    status_bar.showMessage("Recording stopped", 5000)

    notice = "Stop recording requested. This would stop the recording in a real implementation."
    QMessageBox.information(self, "Info", notice)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Create the Qt application, show the main window, and hand Qt's exit
    # status back to the shell once the event loop finishes.
    qt_app = QApplication(sys.argv)
    main_window = AlohaDataCollectionGUI()
    main_window.show()
    exit_status = qt_app.exec_()
    sys.exit(exit_status)
|
||||||
@@ -1,487 +0,0 @@
|
|||||||
import logging
|
|
||||||
import time
|
|
||||||
from dataclasses import asdict
|
|
||||||
from pprint import pformat
|
|
||||||
from pprint import pprint
|
|
||||||
|
|
||||||
# from safetensors.torch import load_file, save_file
|
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
from lerobot.common.policies.factory import make_policy
|
|
||||||
from lerobot.common.robot_devices.control_configs import (
|
|
||||||
CalibrateControlConfig,
|
|
||||||
ControlPipelineConfig,
|
|
||||||
RecordControlConfig,
|
|
||||||
RemoteRobotConfig,
|
|
||||||
ReplayControlConfig,
|
|
||||||
TeleoperateControlConfig,
|
|
||||||
)
|
|
||||||
from lerobot.common.robot_devices.control_utils import (
|
|
||||||
# init_keyboard_listener,
|
|
||||||
record_episode,
|
|
||||||
stop_recording,
|
|
||||||
is_headless
|
|
||||||
)
|
|
||||||
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
|
||||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
|
||||||
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
|
||||||
from lerobot.common.utils.utils import get_safe_torch_device
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from copy import copy
|
|
||||||
import torch
|
|
||||||
import rospy
|
|
||||||
import cv2
|
|
||||||
from lerobot.configs import parser
|
|
||||||
from agilex_robot import AgilexRobot
|
|
||||||
|
|
||||||
|
|
||||||
########################################################################################
|
|
||||||
# Control modes
|
|
||||||
########################################################################################
|
|
||||||
|
|
||||||
|
|
||||||
def predict_action(observation, policy, device, use_amp):
    """Run ``policy`` on one observation and return the action as a CPU tensor.

    The observation mapping is shallow-copied, so the caller's dict is not
    modified. Entries whose key contains "image" are converted to float32,
    divided by 255 and permuted from channel-last to channel-first; every
    entry then gets a leading batch dimension and is moved to ``device``.
    Autocast is enabled only when ``device`` is CUDA and ``use_amp`` is true.
    """
    batch = copy(observation)

    amp_enabled = device.type == "cuda" and use_amp
    autocast_ctx = torch.autocast(device_type=device.type) if amp_enabled else nullcontext()

    with torch.inference_mode(), autocast_ctx:
        for key in batch:
            tensor = batch[key]
            if "image" in key:
                # channel-last values in [0, 255] -> channel-first float32 in [0, 1]
                tensor = tensor.type(torch.float32) / 255
                tensor = tensor.permute(2, 0, 1).contiguous()
            batch[key] = tensor.unsqueeze(0).to(device)

        # Compute the next action from the current observation, then strip
        # the batch dimension and bring the result back to the CPU.
        predicted = policy.select_action(batch)
        predicted = predicted.squeeze(0).to("cpu")

    return predicted
|
|
||||||
|
|
||||||
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Run the observe/act/record loop until time runs out or an exit is requested.

    Args:
        robot: Object providing ``teleop_step()``, ``capture_observation()``
            and ``send_action()``.
        control_time_s: Maximum loop duration in seconds; ``None`` means no limit.
        teleoperate: When true, each step comes from ``robot.teleop_step()``;
            otherwise the observation is captured and, if ``policy`` is
            given, the predicted action is sent to the robot.
        display_cameras: Show the color image streams in OpenCV windows
            (ignored in a headless environment).
        dataset: When given, every step is appended with ``dataset.add_frame``.
        events: Shared dict; a true ``"exit_early"`` entry ends the loop and
            is reset to ``False``. A private dict is used when omitted.
        policy: Optional policy used when not teleoperating.
        fps: Target loop frequency; ``None`` disables pacing.
        single_task: Task description stored with every recorded frame.

    Raises:
        ValueError: If ``dataset`` is given without ``single_task``, or if
            ``dataset.fps`` differs from ``fps``.
    """
    # TODO(rcadene): Add option to record logs
    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        # Use the attribute: subscripting the dataset here raised the wrong
        # exception instead of this message.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    # rospy.Rate needs a real frequency, so only build it when fps is set.
    rate = rospy.Rate(fps) if fps is not None else None
    print_flag = True

    # Layout constants for the preview windows (assumes a 1920-px-wide screen).
    screen_width = 1920
    window_width = 640
    window_height = 480
    max_columns = screen_width // window_width

    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                # Not synchronized yet: report once, wait a tick and retry.
                if print_flag:
                    print("sync data fail, retrying...\n")
                    print_flag = False
                if rate is not None:
                    rate.sleep()
                else:
                    time.sleep(0.01)  # avoid a hot spin when no fps is set
                continue
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            # Depth streams are not displayed; drop them before assigning grid
            # slots so the visible windows form a gap-free grid.
            color_keys = [key for key in observation if "image" in key and "depth" not in key]

            for idx, key in enumerate(color_keys):
                # OpenCV expects BGR, the observation holds RGB.
                image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
                cv2.imshow(key, image)

                row, col = divmod(idx, max_columns)
                cv2.moveWindow(key, col * window_width, row * window_height)

            # Give the GUI 1 ms to process its events.
            cv2.waitKey(1)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
|
|
||||||
|
|
||||||
|
|
||||||
def init_keyboard_listener():
    """Start a keyboard listener that flips flags in a shared ``events`` dict.

    Right arrow ends the current loop, left arrow ends it and requests a
    re-record, escape ends it and requests a full stop, up arrow requests a
    recording start. Monitoring keyboard events might require sudo
    permission for the terminal.

    Returns:
        ``(listener, events)``. In a headless environment no listener is
        started and ``listener`` is ``None``.
    """
    events = {
        "exit_early": False,
        "record_start": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    # (key, message to print, flag updates applied in this order)
    bindings = (
        (
            keyboard.Key.right,
            "Right arrow key pressed. Exiting loop...",
            (("exit_early", True), ("record_start", False)),
        ),
        (
            keyboard.Key.left,
            "Left arrow key pressed. Exiting loop and rerecord the last episode...",
            (("rerecord_episode", True), ("exit_early", True)),
        ),
        (
            keyboard.Key.esc,
            "Escape key pressed. Stopping data recording...",
            (("stop_recording", True), ("exit_early", True)),
        ),
        (
            keyboard.Key.up,
            "Up arrow pressed. Start data recording...",
            (("record_start", True),),
        ),
    )

    def on_press(key):
        try:
            for trigger, notice, updates in bindings:
                if key == trigger:
                    print(notice)
                    for flag, value in updates:
                        events[flag] = value
                    break
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events
|
|
||||||
|
|
||||||
|
|
||||||
def stop_recording(robot, listener, display_cameras):
    """Stop the keyboard listener and close the preview windows.

    Does nothing in a headless environment. ``robot`` is accepted but not used.
    """
    if is_headless():
        return

    if listener is not None:
        listener.stop()

    if display_cameras:
        cv2.destroyAllWindows()
|
|
||||||
|
|
||||||
|
|
||||||
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    """Record one episode by delegating to ``control_loop``.

    The robot is teleoperated when no ``policy`` is given; otherwise the
    policy drives it.
    """
    use_teleoperation = policy is None
    loop_kwargs = {
        "robot": robot,
        "control_time_s": episode_time_s,
        "display_cameras": display_cameras,
        "dataset": dataset,
        "events": events,
        "policy": policy,
        "fps": fps,
        "teleoperate": use_teleoperation,
        "single_task": single_task,
    }
    control_loop(**loop_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def record(
    robot,
    cfg
) -> LeRobotDataset:
    """Record ``cfg.num_episodes`` episodes into a LeRobot dataset and return it.

    With ``cfg.resume`` an existing dataset is opened (and its image writer
    started when the robot has cameras); otherwise a new dataset is created
    from ``robot.features``. Each episode is recorded with ``record_episode``
    and saved unless a re-record was requested, in which case its buffer is
    discarded and the episode is recorded again. A stop request ends the loop
    after the current episode is saved. The dataset is pushed to the hub
    when ``cfg.push_to_hub`` is set.
    """
    # TODO(rcadene): Add option to record logs
    camera_count = len(robot.cameras)
    writer_threads = cfg.num_image_writer_threads_per_camera * camera_count

    if not cfg.resume:
        # Create empty dataset or load existing saved episodes
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=writer_threads,
        )
    else:
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        if camera_count > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=writer_threads,
            )

    # Load pretrained policy
    policy = make_policy(cfg.policy, ds_meta=dataset.meta) if cfg.policy is not None else None

    listener, events = init_keyboard_listener()

    log_say("Warmup record", cfg.play_sounds)
    print()
    # Operator instructions (in Chinese): number of episodes to record, the
    # per-episode limit, and the right/up/left/ESC key bindings.
    print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")

    recorded_episodes = 0
    while recorded_episodes < cfg.num_episodes:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Announce the environment reset, except after the last episode
        # (unless it is about to be re-recorded) or when stopping.
        # TODO(rcadene): add an option to enable teleoperation during reset
        more_to_come = recorded_episodes < cfg.num_episodes - 1
        if not events["stop_recording"] and (more_to_come or events["rerecord_episode"]):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")

        if events["rerecord_episode"]:
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            # Drop the frames of the rejected attempt and try again.
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset
|
|
||||||
|
|
||||||
|
|
||||||
def replay(
    robot: AgilexRobot,
    cfg,
):
    """Send the recorded actions of episode ``cfg.episode`` to the robot at ``cfg.fps``."""
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs
    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    recorded_actions = dataset.hf_dataset.select_columns("action")

    log_say("Replaying episode", cfg.play_sounds, blocking=True)

    frame_period_s = 1 / cfg.fps
    for frame_idx in range(dataset.num_frames):
        step_start_t = time.perf_counter()

        robot.send_action(recorded_actions[frame_idx]["action"])

        # Wait out the remainder of the frame period to keep the original pace.
        elapsed_s = time.perf_counter() - step_start_t
        busy_wait(frame_period_s - elapsed_s)
|
|
||||||
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
def get_arguments():
    """Return the run settings as an ``argparse.Namespace``.

    No command-line options are defined; every setting is hard-coded below
    and attached to the namespace returned by ``parse_args()``.
    """
    settings = argparse.ArgumentParser().parse_args()

    hard_coded = {
        "fps": 30,
        "resume": False,
        "repo_id": "move_the_bottle_from_the_right_to_the_scale_right",
        "root": "/home/ubuntu/LYT/aloha_lerobot/data4",
        "episode": 0,  # replay episode
        "num_image_writer_processes": 0,
        "num_image_writer_threads_per_camera": 4,
        "video": True,
        "num_episodes": 100,
        "episode_time_s": 30000,
        "play_sounds": False,
        "display_cameras": True,
        "single_task": "move the bottle from the right to the scale right",
        "use_depth_image": False,
        "use_base": False,
        "push_to_hub": False,
        "policy": None,
        # Switch to "replay" to play back the episode selected above.
        "control_type": "record",
    }
    for attr_name, value in hard_coded.items():
        setattr(settings, attr_name, value)

    return settings
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# @parser.wrap()
|
|
||||||
# def control_robot(cfg: ControlPipelineConfig):
|
|
||||||
# init_logging()
|
|
||||||
# logging.info(pformat(asdict(cfg)))
|
|
||||||
|
|
||||||
# # robot = make_robot_from_config(cfg.robot)
|
|
||||||
# from agilex_robot import AgilexRobot
|
|
||||||
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
|
||||||
|
|
||||||
# if isinstance(cfg.control, RecordControlConfig):
|
|
||||||
# print(cfg.control)
|
|
||||||
# record(robot, cfg.control)
|
|
||||||
# elif isinstance(cfg.control, ReplayControlConfig):
|
|
||||||
# replay(robot, cfg.control)
|
|
||||||
|
|
||||||
# # if robot.is_connected:
|
|
||||||
# # # Disconnect manually to avoid a "Core dump" during process
|
|
||||||
# # # termination due to camera threads not properly exiting.
|
|
||||||
# # robot.disconnect()
|
|
||||||
|
|
||||||
|
|
||||||
# @parser.wrap()
|
|
||||||
def control_robot(cfg):
    """Build the Agilex robot and run the mode selected by ``cfg.control_type``.

    "record" runs ``record`` and "replay" runs ``replay``; any other value
    only constructs the robot.
    """
    from agilex_robot import AgilexRobot

    robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)

    mode = cfg.control_type
    if mode == "record":
        record(robot, cfg)
        return
    if mode == "replay":
        replay(robot, cfg)
|
|
||||||
|
|
||||||
# if robot.is_connected:
|
|
||||||
# # Disconnect manually to avoid a "Core dump" during process
|
|
||||||
# # termination due to camera threads not properly exiting.
|
|
||||||
# robot.disconnect()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Build the hard-coded settings and run the selected control mode.
    control_robot(get_arguments())
|
|
||||||
# control_robot()
|
|
||||||
# cfg = get_arguments()
|
|
||||||
# from agilex_robot import AgilexRobot
|
|
||||||
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
|
||||||
# print(robot.features.items())
|
|
||||||
# print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
|
|
||||||
# record(robot, cfg)
|
|
||||||
# capture = robot.capture_observation()
|
|
||||||
# import torch
|
|
||||||
# torch.save(capture, "test.pt")
|
|
||||||
# action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094,
|
|
||||||
# 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]],
|
|
||||||
# device='cpu')
|
|
||||||
# robot.send_action(action.squeeze(0))
|
|
||||||
# print()
|
|
||||||
292
collect_data/convert_aloha_data_to_lerobot.py
Normal file
292
collect_data/convert_aloha_data_to_lerobot.py
Normal file
@@ -0,0 +1,292 @@
|
|||||||
|
"""
|
||||||
|
Script to convert Aloha hdf5 data to the LeRobot dataset v2.0 format.
|
||||||
|
|
||||||
|
Example usage: uv run collect_data/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --repo-id <org>/<dataset-name>
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
from typing import Literal
|
||||||
|
import os
|
||||||
|
import h5py
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.datasets.push_dataset_to_hub._download_raw import download_raw
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import tqdm
|
||||||
|
import tyro
|
||||||
|
|
||||||
|
# Override the LEROBOT_HOME imported from lerobot above with a custom,
# machine-specific dataset location.
LEROBOT_HOME = Path("/home/ubuntu/hdd0/lerobot_datasets/3camera")
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    """Immutable bundle of dataset-creation options.

    NOTE(review): these fields are presumably forwarded to the LeRobot
    dataset constructor by ``create_empty_dataset``; the consuming code is
    not visible here - confirm before relying on their exact meaning.
    """

    # Enabled by default.
    use_videos: bool = True
    # Presumably a timestamp tolerance in seconds (going by the `_s` suffix) - TODO confirm.
    tolerance_s: float = 0.0001
    # Presumably the number of image-writer worker processes - TODO confirm.
    image_writer_processes: int = 10
    # Presumably the number of image-writer threads - TODO confirm whether per process.
    image_writer_threads: int = 5
    # No video backend is selected by default.
    video_backend: str | None = None


# Shared default instance; used as the default value of the `dataset_config`
# parameter of `create_empty_dataset`.
DEFAULT_DATASET_CONFIG = DatasetConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    cameras: list[str] | None = None,
    fps: int = 30,
) -> LeRobotDataset:
    """Create a fresh, empty LeRobotDataset describing a 14-motor bimanual robot.

    Any existing directory at ``LEROBOT_HOME / repo_id`` is deleted first.

    Args:
        repo_id: Dataset identifier; also the sub-directory under LEROBOT_HOME.
        robot_type: Stored as the dataset's robot type.
        mode: Feature dtype used for every camera stream ("video" or "image").
        has_velocity: Add an ``observation.velocity`` feature.
        has_effort: Add an ``observation.effort`` feature.
        dataset_config: Writer/back-end options forwarded to ``LeRobotDataset.create``.
        cameras: Camera names to declare. Defaults to
            ``["cam_high", "cam_left_wrist", "cam_right_wrist"]``.
        fps: Frame rate recorded in the dataset metadata (default 30).

    Returns:
        The newly created, empty dataset.
    """
    motors = [
        "right_waist",
        "right_shoulder",
        "right_elbow",
        "right_forearm_roll",
        "right_wrist_angle",
        "right_wrist_rotate",
        "right_gripper",
        "left_waist",
        "left_shoulder",
        "left_elbow",
        "left_forearm_roll",
        "left_wrist_angle",
        "left_wrist_rotate",
        "left_gripper",
    ]
    if cameras is None:
        cameras = [
            "cam_high",
            "cam_left_wrist",
            "cam_right_wrist",
        ]

    def _motor_feature() -> dict:
        # One float32 vector with one entry per motor; names are nested in a
        # single-element list exactly as the original layout did.
        return {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": [
                motors,
            ],
        }

    features = {
        "observation.state": _motor_feature(),
        "action": _motor_feature(),
    }

    if has_velocity:
        features["observation.velocity"] = _motor_feature()

    if has_effort:
        features["observation.effort"] = _motor_feature()

    for cam in cameras:
        features[f"observation.images.{cam}"] = {
            "dtype": mode,
            "shape": (3, 480, 640),
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    # LEROBOT_HOME is already a Path, so the join needs no extra Path() wrapping.
    dataset_root = LEROBOT_HOME / repo_id
    if dataset_root.exists():
        shutil.rmtree(dataset_root)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        root=dataset_root,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_cameras(hdf5_files: list[Path]) -> list[str]:
    """Return the non-depth camera names found in the first episode file.

    Args:
        hdf5_files: Episode files; only the first one is inspected.

    Returns:
        Keys of ``/observations/images`` whose name does not contain "depth".

    Raises:
        ValueError: If ``hdf5_files`` is empty.
    """
    if not hdf5_files:
        raise ValueError("hdf5_files is empty; cannot determine camera names")
    with h5py.File(hdf5_files[0], "r") as ep:
        # ignore depth channel, not currently handled
        return [key for key in ep["/observations/images"].keys() if "depth" not in key]  # noqa: SIM118
|
||||||
|
|
||||||
|
|
||||||
|
def has_velocity(hdf5_files: list[Path]) -> bool:
    """Report whether the first episode file contains ``/observations/qvel``.

    Raises:
        ValueError: If ``hdf5_files`` is empty.
    """
    if not hdf5_files:
        raise ValueError("hdf5_files is empty; cannot check for /observations/qvel")
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/qvel" in ep
|
||||||
|
|
||||||
|
|
||||||
|
def has_effort(hdf5_files: list[Path]) -> bool:
    """Report whether the first episode file contains ``/observations/effort``.

    Raises:
        ValueError: If ``hdf5_files`` is empty.
    """
    if not hdf5_files:
        raise ValueError("hdf5_files is empty; cannot check for /observations/effort")
    with h5py.File(hdf5_files[0], "r") as ep:
        return "/observations/effort" in ep
|
||||||
|
|
||||||
|
|
||||||
|
def load_raw_images_per_camera(ep: "h5py.File", cameras: list[str]) -> dict[str, np.ndarray]:
    """Load every frame of each requested camera from an open episode file.

    A 4-dimensional dataset is treated as uncompressed and read in one go.
    Anything else is treated as a sequence of encoded images that are decoded
    one by one with OpenCV and converted from BGR to RGB.

    Args:
        ep: Open episode file (any mapping indexed by dataset path works).
        cameras: Camera names under ``/observations/images``.

    Returns:
        Mapping from camera name to an array holding all of its frames.

    Raises:
        ValueError: If an encoded frame cannot be decoded.
    """
    imgs_per_cam = {}
    for camera in cameras:
        # Look the dataset up once instead of once per use.
        dataset = ep[f"/observations/images/{camera}"]

        if dataset.ndim == 4:
            # load all images in RAM
            imgs_per_cam[camera] = dataset[:]
            continue

        import cv2

        # load one compressed image after the other in RAM and uncompress
        frames = []
        for frame_idx, data in enumerate(dataset):
            decoded = cv2.imdecode(data, cv2.IMREAD_COLOR)
            if decoded is None:
                # imdecode signals failure by returning None; fail with context
                # instead of a cryptic cvtColor error.
                raise ValueError(f"Failed to decode frame {frame_idx} of camera '{camera}'")
            frames.append(cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB))
        imgs_per_cam[camera] = np.array(frames)

    return imgs_per_cam
|
||||||
|
|
||||||
|
|
||||||
|
def load_raw_episode_data(
    ep_path: Path,
    cameras: list[str] | None = None,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
    """Read one episode file into memory.

    Args:
        ep_path: Path of the episode HDF5 file.
        cameras: Camera names to load. Defaults to
            ``["cam_high", "cam_left_wrist", "cam_right_wrist"]``.

    Returns:
        ``(imgs_per_cam, state, action, velocity, effort)`` where ``state``
        comes from ``/observations/qpos``, ``action`` from ``/action``, and
        ``velocity`` / ``effort`` are None when ``/observations/qvel`` /
        ``/observations/effort`` are absent from the file.
    """
    if cameras is None:
        cameras = [
            "cam_high",
            "cam_left_wrist",
            "cam_right_wrist",
        ]

    with h5py.File(ep_path, "r") as ep:
        state = torch.from_numpy(ep["/observations/qpos"][:])
        action = torch.from_numpy(ep["/action"][:])

        velocity = None
        if "/observations/qvel" in ep:
            velocity = torch.from_numpy(ep["/observations/qvel"][:])

        effort = None
        if "/observations/effort" in ep:
            effort = torch.from_numpy(ep["/observations/effort"][:])

        imgs_per_cam = load_raw_images_per_camera(ep, cameras)

    return imgs_per_cam, state, action, velocity, effort
|
||||||
|
|
||||||
|
|
||||||
|
def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: list[Path],
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    """Append the selected raw episodes to ``dataset``, one frame at a time.

    Args:
        dataset: Target dataset; frames are added and each episode is saved.
        hdf5_files: Raw episode files, indexed by the values in ``episodes``.
        task: Task description stored with every saved episode.
        episodes: Indices into ``hdf5_files`` to convert; all files when None.

    Returns:
        The same ``dataset`` object, now populated.
    """
    episode_indices = range(len(hdf5_files)) if episodes is None else episodes

    for ep_idx in tqdm.tqdm(episode_indices):
        imgs_per_cam, state, action, velocity, effort = load_raw_episode_data(hdf5_files[ep_idx])

        # The number of frames is taken from the state array.
        for i in range(state.shape[0]):
            frame = {
                "observation.state": state[i],
                "action": action[i],
            }
            frame.update(
                {f"observation.images.{camera}": img_array[i] for camera, img_array in imgs_per_cam.items()}
            )

            if velocity is not None:
                frame["observation.velocity"] = velocity[i]
            if effort is not None:
                frame["observation.effort"] = effort[i]

            dataset.add_frame(frame)

        dataset.save_episode(task=task)

    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def port_aloha(
    raw_dir: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = False,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    """Convert a directory of Aloha ``episode_*.hdf5`` files into a LeRobot dataset.

    Args:
        raw_dir: Directory holding the raw episode files.
        repo_id: Target dataset id; output goes to ``LEROBOT_HOME / repo_id``.
        raw_repo_id: Repository to download the raw data from when ``raw_dir``
            does not exist.
        task: Task description stored with every episode.
        episodes: Indices of the (sorted) episode files to convert; all when None.
        push_to_hub: Upload the finished dataset with ``dataset.push_to_hub()``.
        is_mobile: Record the robot type as "mobile_aloha" instead of "aloha".
        mode: Store camera streams as "video" or "image" features.
        dataset_config: Writer/back-end options for the dataset.

    Raises:
        ValueError: If ``raw_dir`` is missing and no ``raw_repo_id`` is given.
        FileNotFoundError: If ``raw_dir`` contains no ``episode_*.hdf5`` file.
    """
    print(LEROBOT_HOME)

    if not raw_dir.exists():
        if raw_repo_id is None:
            raise ValueError("raw_repo_id must be provided if raw_dir does not exist")
        download_raw(raw_dir, repo_id=raw_repo_id)

    hdf5_files = sorted(raw_dir.glob("episode_*.hdf5"))
    if not hdf5_files:
        # Fail with a clear message instead of an IndexError further down.
        raise FileNotFoundError(f"No episode_*.hdf5 files found in {raw_dir}")

    # Only discard a previous conversion once the inputs are known to be usable.
    dataset_root = LEROBOT_HOME / repo_id
    if dataset_root.exists():
        shutil.rmtree(dataset_root)

    dataset = create_empty_dataset(
        repo_id,
        robot_type="mobile_aloha" if is_mobile else "aloha",
        mode=mode,
        has_effort=has_effort(hdf5_files),
        has_velocity=has_velocity(hdf5_files),
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        hdf5_files,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Expose port_aloha as the command-line entry point via tyro.
    tyro.cli(port_aloha)
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5
|
|
||||||
# export LD_LIBRARY_PATH=/home/ubuntu/miniconda3/envs/lerobot/lib:$LD_LIBRARY_PATH
|
|
||||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
|
|
||||||
@@ -1,769 +0,0 @@
|
|||||||
#!/home/lin/software/miniconda3/envs/aloha/bin/python
|
|
||||||
# -- coding: UTF-8
|
|
||||||
"""
|
|
||||||
#!/usr/bin/python3
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
import argparse
|
|
||||||
from einops import rearrange
|
|
||||||
import collections
|
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
import rospy
|
|
||||||
from std_msgs.msg import Header
|
|
||||||
from geometry_msgs.msg import Twist
|
|
||||||
from sensor_msgs.msg import JointState, Image
|
|
||||||
from nav_msgs.msg import Odometry
|
|
||||||
from cv_bridge import CvBridge
|
|
||||||
import time
|
|
||||||
import threading
|
|
||||||
import math
|
|
||||||
import threading
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
import sys
|
|
||||||
sys.path.append("./")
|
|
||||||
|
|
||||||
# Fixed seed applied to both the torch and the numpy random number generators.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Camera names; get_model_config copies this list into the policy and run configs.
task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']}

# State shared between model_inference and inference_process.
# inference_lock guards inference_actions / inference_timestep.
inference_thread = None
inference_lock = threading.Lock()
inference_actions = None
inference_timestep = None
|
|
||||||
|
|
||||||
|
|
||||||
def actions_interpolation(args, pre_action, actions, stats):
    """Build a step-limited trajectory from ``pre_action`` through the predicted actions.

    The predicted chunk ``actions[0]`` is de-normalized with the action
    statistics. Starting from the predicted action that differs most from
    ``pre_action`` (indices 6 and 13 are ignored in that comparison), the
    remaining targets are connected by linear interpolation so that no joint
    moves more than its ``args.arm_steps_length`` entry per sample. The result
    is padded/truncated to ``args.chunk_size`` samples and normalized with the
    qpos statistics.

    Args:
        args: Needs ``arm_steps_length`` (per-joint step for one arm) and ``chunk_size``.
        pre_action: Starting joint vector.
        actions: Normalized predictions; only ``actions[0]`` is used.
        stats: Mapping with ``action_mean``, ``action_std``, ``qpos_mean``, ``qpos_std``.

    Returns:
        Array with a leading axis of length 1 followed by ``chunk_size`` samples.
    """
    # NOTE(review): output is normalized with the qpos statistics although the
    # input was de-normalized with the action statistics -- confirm intended.
    one_arm_steps = np.array(args.arm_steps_length)
    per_joint_step = np.concatenate((one_arm_steps, one_arm_steps), axis=0)
    targets = actions[0] * stats['action_std'] + stats['action_mean']

    # Locate the target farthest (L1 distance) from the starting pose.
    farthest_idx, farthest_dist = 0, -1
    for idx in range(targets.shape[0]):
        dist = 0
        for joint in range(pre_action.shape[0]):
            if joint not in (6, 13):
                dist += math.fabs(pre_action[joint] - targets[idx][joint])
        if dist > farthest_dist:
            farthest_dist, farthest_idx = dist, idx

    trajectory = [pre_action]
    for idx in range(farthest_idx, targets.shape[0]):
        previous = trajectory[-1]
        goal = targets[idx]
        extra_points = max(
            [math.floor(math.fabs(previous[joint] - goal[joint]) / per_joint_step[joint])
             for joint in range(pre_action.shape[0])]
        )
        segment = np.linspace(previous, goal, extra_points + 2)
        trajectory.extend(segment[1:])

    # Repeat the last sample until the chunk is full.
    while len(trajectory) < args.chunk_size + 1:
        trajectory.append(trajectory[-1])

    chunk = np.array(trajectory)[1:args.chunk_size + 1]
    normalized = (chunk - stats['qpos_mean']) / stats['qpos_std']
    return normalized[np.newaxis, :]
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_config(args):
    """Assemble the run configuration dict from the parsed command-line ``args``.

    Supports the policy classes 'ACT', 'CNNMLP' and 'Diffusion'; any other
    value raises NotImplementedError.

    Returns:
        Dict with checkpoint locations, episode length, state dimension, the
        policy class, its ``policy_config`` dict, the temporal aggregation
        flag and the camera names.
    """
    # Seed the random number generators so that, given the same initial
    # conditions, every run produces the same random sequence.
    set_seed(1)

    policy_class = args.policy_class
    if policy_class not in ('ACT', 'CNNMLP', 'Diffusion'):
        raise NotImplementedError

    # Keys shared by all three policy classes (fixed parameters).
    policy_config = {
        'lr': args.lr,
        'lr_backbone': args.lr_backbone,
        'backbone': args.backbone,
        'masks': args.masks,
        'weight_decay': args.weight_decay,
        'dilation': args.dilation,
        'position_embedding': args.position_embedding,
        'loss_function': args.loss_function,
        # Number of queries; CNNMLP always uses a single one.
        'chunk_size': 1 if policy_class == 'CNNMLP' else args.chunk_size,
        'camera_names': task_config['camera_names'],
        'use_depth_image': args.use_depth_image,
        'use_robot_base': args.use_robot_base,
    }

    if policy_class == 'ACT':
        policy_config.update({
            'kl_weight': args.kl_weight,  # KL divergence weight
            'hidden_dim': args.hidden_dim,  # hidden layer dimension
            'dim_feedforward': args.dim_feedforward,
            'enc_layers': args.enc_layers,
            'dec_layers': args.dec_layers,
            'nheads': args.nheads,
            'dropout': args.dropout,
            'pre_norm': args.pre_norm,
        })
    elif policy_class == 'Diffusion':
        policy_config.update({
            'observation_horizon': args.observation_horizon,
            'action_horizon': args.action_horizon,
            'num_inference_timesteps': args.num_inference_timesteps,
            'ema_power': args.ema_power,
        })

    return {
        'ckpt_dir': args.ckpt_dir,
        'ckpt_name': args.ckpt_name,
        'ckpt_stats_name': args.ckpt_stats_name,
        'episode_len': args.max_publish_step,
        'state_dim': args.state_dim,
        'policy_class': args.policy_class,
        'policy_config': policy_config,
        'temporal_agg': args.temporal_agg,
        'camera_names': task_config['camera_names'],
    }
|
|
||||||
|
|
||||||
|
|
||||||
def make_policy(policy_class, policy_config):
    """Instantiate the policy named by ``policy_class`` with ``policy_config``.

    Raises:
        NotImplementedError: If ``policy_class`` is not 'ACT', 'CNNMLP' or 'Diffusion'.
    """
    if policy_class == 'ACT':
        return ACTPolicy(policy_config)
    if policy_class == 'CNNMLP':
        return CNNMLPPolicy(policy_config)
    if policy_class == 'Diffusion':
        return DiffusionPolicy(policy_config)
    raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def get_image(observation, camera_names):
    """Stack the colour images of ``camera_names`` into one CUDA float batch.

    Each image in ``observation['images']`` is rearranged from (h, w, c) to
    (c, h, w); the images are stacked along a new first axis, divided by
    255.0, moved to the GPU and given a leading axis of length 1.
    """
    channel_first = [
        rearrange(observation['images'][cam_name], 'h w c -> c h w')
        for cam_name in camera_names
    ]
    stacked = np.stack(channel_first, axis=0)
    scaled = torch.from_numpy(stacked / 255.0).float()
    return scaled.cuda().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def get_depth_image(observation, camera_names):
    """Stack the depth images of ``camera_names`` into one CUDA float batch.

    Images are taken as-is from ``observation['images_depth']``, stacked along
    a new first axis, divided by 255.0, moved to the GPU and given a leading
    axis of length 1.
    """
    depth_frames = [observation['images_depth'][cam_name] for cam_name in camera_names]
    stacked = np.stack(depth_frames, axis=0)
    scaled = torch.from_numpy(stacked / 255.0).float()
    return scaled.cuda().unsqueeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def inference_process(args, config, ros_operator, policy, stats, t, pre_action):
    """Run one policy forward pass on the next synchronized frame.

    Polls ``ros_operator.get_frame()`` at ``args.publish_rate`` until a frame
    is available (or ROS shuts down), builds the observation, queries
    ``policy`` and stores the resulting actions in the module-level
    ``inference_actions`` together with ``inference_timestep = t``.

    Args:
        args: Parsed options (``publish_rate``, ``use_depth_image``,
            ``use_robot_base``, ``use_actions_interpolation``, ...).
        config: Run configuration; ``config['camera_names']`` must list the
            front, left and right camera names in that order.
        ros_operator: Object providing ``get_frame()``.
        policy: Callable taking ``(image, depth_image, qpos)``.
        stats: Normalization statistics (``qpos_mean``, ``qpos_std``, ...).
        t: Time step recorded alongside the result.
        pre_action: Previous action used as the interpolation start; the
            current joint positions are used when it is None.
    """
    global inference_actions
    global inference_timestep
    print_flag = True
    pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    rate = rospy.Rate(args.publish_rate)
    while not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if not result:
            # Report a synchronization failure only once per streak.
            if print_flag:
                print("syn fail")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
         puppet_arm_left, puppet_arm_right, robot_base) = result
        obs = collections.OrderedDict()
        image_dict = dict()

        image_dict[config['camera_names'][0]] = img_front
        image_dict[config['camera_names'][1]] = img_left
        image_dict[config['camera_names'][2]] = img_right

        obs['images'] = image_dict

        if args.use_depth_image:
            image_depth_dict = dict()
            image_depth_dict[config['camera_names'][0]] = img_front_depth
            image_depth_dict[config['camera_names'][1]] = img_left_depth
            image_depth_dict[config['camera_names'][2]] = img_right_depth
            obs['images_depth'] = image_depth_dict

        obs['qpos'] = np.concatenate(
            (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
        obs['qvel'] = np.concatenate(
            (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
        obs['effort'] = np.concatenate(
            (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
        if args.use_robot_base:
            obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
            obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0)
        else:
            obs['base_vel'] = [0.0, 0.0]

        # Normalize qpos and move it to CUDA.
        qpos = pre_pos_process(obs['qpos'])
        qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
        # Build the current image batch.
        curr_image = get_image(obs, config['camera_names'])
        curr_depth_image = None
        if args.use_depth_image:
            curr_depth_image = get_depth_image(obs, config['camera_names'])
        start_time = time.time()
        all_actions = policy(curr_image, curr_depth_image, qpos)
        end_time = time.time()
        print("model cost time: ", end_time - start_time)
        # Use the lock as a context manager so it is released even if the
        # post-processing below raises; otherwise the consumer would block forever.
        with inference_lock:
            inference_actions = all_actions.cpu().detach().numpy()
            if pre_action is None:
                pre_action = obs['qpos']
            if args.use_actions_interpolation:
                inference_actions = actions_interpolation(args, pre_action, inference_actions, stats)
            inference_timestep = t
        break
|
|
||||||
|
|
||||||
|
|
||||||
def model_inference(args, config, ros_operator, save_episode=True):
    """Load the policy from its checkpoint and drive the robot arms in a loop.

    Builds the policy from ``config``, loads its weights and the
    normalization statistics, moves the arms to two fixed start poses
    (waiting for the operator in between) and then repeatedly queries the
    policy and publishes the resulting joint targets until ROS shuts down.

    Args:
        args: Parsed options (``publish_rate``, ``pos_lookahead_step``,
            ``use_robot_base``, ...).
        config: Dict produced by get_model_config.
        ros_operator: Publisher/subscriber wrapper used to move the robot.
        save_episode: Not used inside this function.

    Returns:
        False when ``policy.deserialize`` reports failure; otherwise the
        function only returns (None) once ROS shuts down.
    """
    global inference_lock
    global inference_actions
    global inference_timestep
    global inference_thread
    set_seed(1000)

    # 1. Create the model (inherits from nn.Module).
    policy = make_policy(config['policy_class'], config['policy_config'])
    # print("model structure\n", policy.model)

    # 2. Load the model weights.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name'])
    state_dict = torch.load(ckpt_path)
    new_state_dict = {}
    for key, value in state_dict.items():
        # These heads are dropped from the checkpoint before loading.
        if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]:
            continue
        if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]:
            continue
        new_state_dict[key] = value
    loading_status = policy.deserialize(new_state_dict)
    if not loading_status:
        print("ckpt path not exist")
        return False

    # 3. Move the model to CUDA and switch it to evaluation mode.
    policy.cuda()
    policy.eval()

    # 4. Load the statistics.
    stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name'])
    # Statistics: action_mean, action_std, qpos_mean, qpos_std (14-dimensional).
    # NOTE(review): pickle.load executes arbitrary code from the file -- only load trusted stats files.
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # Pre-processing and post-processing function definitions.
    # (pre_process is defined but not used below.)
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    max_publish_step = config['episode_len']
    chunk_size = config['policy_config']['chunk_size']

    # Publish the base poses.
    left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
    right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
    left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
    right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]

    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Enter any key to continue :")
    ros_operator.puppet_arm_publish_continuous(left1, right1)
    action = None
    # Inference.
    with torch.inference_mode():
        while True and not rospy.is_shutdown():
            # Step counter for the current episode.
            t = 0
            max_t = 0
            rate = rospy.Rate(args.publish_rate)
            if config['temporal_agg']:
                # Row = step at which a chunk was predicted, column = step it applies to.
                all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']])
            while t < max_publish_step and not rospy.is_shutdown():
                # start_time = time.time()
                # query policy
                if config['policy_class'] == "ACT":
                    if t >= max_t:
                        pre_action = action
                        # The worker thread is joined immediately, so inference is effectively synchronous.
                        inference_thread = threading.Thread(target=inference_process,
                                                            args=(args, config, ros_operator,
                                                                  policy, stats, t, pre_action))
                        inference_thread.start()
                        inference_thread.join()
                        inference_lock.acquire()
                        if inference_actions is not None:
                            inference_thread = None
                            all_actions = inference_actions
                            inference_actions = None
                            # Next policy query happens pos_lookahead_step steps from now.
                            max_t = t + args.pos_lookahead_step
                            if config['temporal_agg']:
                                all_time_actions[[t], t:t + chunk_size] = all_actions
                        inference_lock.release()
                    if config['temporal_agg']:
                        actions_for_curr_step = all_time_actions[:, t]
                        # Rows that are still all-zero (or contain any zero) count as "not predicted yet".
                        actions_populated = np.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        # Exponentially decaying weights over the populated predictions.
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = exp_weights[:, np.newaxis]
                        raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True)
                    else:
                        if args.pos_lookahead_step != 0:
                            raw_action = all_actions[:, t % args.pos_lookahead_step]
                        else:
                            raw_action = all_actions[:, t % chunk_size]
                else:
                    raise NotImplementedError
                action = post_process(raw_action[0])
                left_action = action[:7]  # first 7 dimensions
                right_action = action[7:14]
                ros_operator.puppet_arm_publish(left_action, right_action)  # puppet_arm_publish_continuous_thread
                if args.use_robot_base:
                    vel_action = action[14:16]
                    ros_operator.robot_base_publish(vel_action)
                t += 1
                # end_time = time.time()
                # print("publish: ", t)
                # print("time:", end_time - start_time)
                # print("left_action:", left_action)
                # print("right_action:", right_action)
                rate.sleep()
|
|
||||||
|
|
||||||
|
|
||||||
class RosOperator:
|
|
||||||
def __init__(self, args):
    """Declare every attribute, then create the buffers and the ROS endpoints.

    Args:
        args: Parsed options, stored as ``self.args``.
    """
    # Message buffers (filled in by init()).
    self.robot_base_deque = self.puppet_arm_right_deque = self.puppet_arm_left_deque = None
    self.img_front_deque = self.img_right_deque = self.img_left_deque = None
    self.img_front_depth_deque = self.img_right_depth_deque = self.img_left_depth_deque = None
    self.bridge = None
    # Publishers.
    self.puppet_arm_left_publisher = self.puppet_arm_right_publisher = None
    self.robot_base_publisher = None
    # Background publishing state.
    self.puppet_arm_publish_thread = self.puppet_arm_publish_lock = None
    self.args = args
    self.ctrl_state = False
    self.ctrl_state_lock = threading.Lock()
    self.init()
    self.init_ros()
|
|
||||||
|
|
||||||
def init(self):
    """Create the CvBridge, one empty deque per input stream and the publish lock.

    The publish lock is left in the acquired state.
    """
    self.bridge = CvBridge()
    for buffer_name in (
        'img_left_deque',
        'img_right_deque',
        'img_front_deque',
        'img_left_depth_deque',
        'img_right_depth_deque',
        'img_front_depth_deque',
        'puppet_arm_left_deque',
        'puppet_arm_right_deque',
        'robot_base_deque',
    ):
        setattr(self, buffer_name, deque())
    self.puppet_arm_publish_lock = threading.Lock()
    # Held from the start; puppet_arm_publish_continuous polls it non-blockingly.
    self.puppet_arm_publish_lock.acquire()
|
|
||||||
|
|
||||||
def puppet_arm_publish(self, left, right):
    """Publish one JointState for the left arm and one for the right arm.

    Args:
        left: Joint positions sent through the left-arm publisher.
        right: Joint positions sent through the right-arm publisher.
    """
    msg = JointState()
    msg.header = Header()
    msg.header.stamp = rospy.Time.now()  # time stamp
    msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
    # The same message object is reused; only the positions change.
    for publisher, positions in (
        (self.puppet_arm_left_publisher, left),
        (self.puppet_arm_right_publisher, right),
    ):
        msg.position = positions
        publisher.publish(msg)
|
|
||||||
|
|
||||||
def robot_base_publish(self, vel):
    """Publish a Twist built from ``vel``.

    Args:
        vel: ``vel[0]`` becomes linear.x and ``vel[1]`` becomes angular.z;
            every other component is set to 0.
    """
    twist = Twist()
    twist.linear.x, twist.angular.z = vel[0], vel[1]
    twist.linear.y = twist.linear.z = 0
    twist.angular.x = twist.angular.y = 0
    self.robot_base_publisher.publish(twist)
|
|
||||||
|
|
||||||
def puppet_arm_publish_continuous(self, left, right):
    """Step both arms toward ``left`` / ``right`` in bounded increments.

    Waits until a joint state has been received for both arms, then on every
    tick of ``args.publish_rate`` moves each joint by at most its
    ``args.arm_steps_length`` entry toward the target and publishes the
    intermediate pose. Stops when every joint has reached its target, when
    ROS shuts down, or when the publish lock can be acquired.

    Args:
        left: Target joint positions for the left arm.
        right: Target joint positions for the right arm.
    """
    rate = rospy.Rate(self.args.publish_rate)
    left_arm = None
    right_arm = None
    # Wait for the most recent joint state of both arms.
    while True and not rospy.is_shutdown():
        if len(self.puppet_arm_left_deque) != 0:
            left_arm = list(self.puppet_arm_left_deque[-1].position)
        if len(self.puppet_arm_right_deque) != 0:
            right_arm = list(self.puppet_arm_right_deque[-1].position)
        if left_arm is None or right_arm is None:
            rate.sleep()
            continue
        else:
            break
    # NOTE(review): if ROS shuts down before both states arrive, left_arm/right_arm
    # are still None here and the indexing below would fail -- confirm.
    # Direction (+1 / -1) in which each joint has to move.
    left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
    right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
    flag = True
    step = 0
    while flag and not rospy.is_shutdown():
        # The lock doubles as a stop signal: it is normally held elsewhere, so a
        # successful non-blocking acquire means someone released it to stop us.
        if self.puppet_arm_publish_lock.acquire(False):
            return
        left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
        right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
        # flag stays False only if every joint snapped to its target this tick.
        flag = False
        for i in range(len(left)):
            if left_diff[i] < self.args.arm_steps_length[i]:
                left_arm[i] = left[i]
            else:
                left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
                flag = True
        for i in range(len(right)):
            if right_diff[i] < self.args.arm_steps_length[i]:
                right_arm[i] = right[i]
            else:
                right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
                flag = True
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # set the time stamp
        joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set the joint names
        joint_state_msg.position = left_arm
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right_arm
        self.puppet_arm_right_publisher.publish(joint_state_msg)
        step += 1
        print("puppet_arm_publish_continuous:", step)
        rate.sleep()
|
|
||||||
|
|
||||||
def puppet_arm_publish_linear(self, left, right):
    """Move both arms to ``left`` / ``right`` along a 100-point linear trajectory.

    Waits for a joint state of both arms, interpolates linearly from the
    current pose to the target and publishes the samples at 200 Hz. The last
    joint of every sample is set to its target value right away.

    Args:
        left: Target joint positions for the left arm.
        right: Target joint positions for the right arm.
    """
    num_step = 100
    rate = rospy.Rate(200)

    left_arm = None
    right_arm = None

    # Wait for the most recent joint state of both arms.
    while not rospy.is_shutdown():
        if self.puppet_arm_left_deque:
            left_arm = list(self.puppet_arm_left_deque[-1].position)
        if self.puppet_arm_right_deque:
            right_arm = list(self.puppet_arm_right_deque[-1].position)
        if left_arm is not None and right_arm is not None:
            break
        rate.sleep()

    traj_left_list = np.linspace(left_arm, left, num_step)
    traj_right_list = np.linspace(right_arm, right, num_step)

    for traj_left, traj_right in zip(traj_left_list, traj_right_list):
        # Override the last joint so it is not interpolated.
        traj_left[-1] = left[-1]
        traj_right[-1] = right[-1]
        msg = JointState()
        msg.header = Header()
        msg.header.stamp = rospy.Time.now()  # time stamp
        msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
        msg.position = traj_left
        self.puppet_arm_left_publisher.publish(msg)
        msg.position = traj_right
        self.puppet_arm_right_publisher.publish(msg)
        rate.sleep()
|
|
||||||
|
|
||||||
def puppet_arm_publish_continuous_thread(self, left, right):
    """Run puppet_arm_publish_continuous in a background thread.

    If a previous publishing thread exists it is asked to stop and joined
    before the new one is started.

    Args:
        left: Target joint positions for the left arm.
        right: Target joint positions for the right arm.
    """
    if self.puppet_arm_publish_thread is not None:
        # Releasing the lock lets the worker's non-blocking acquire succeed,
        # which makes puppet_arm_publish_continuous return.
        self.puppet_arm_publish_lock.release()
        self.puppet_arm_publish_thread.join()
        # Take the lock back without blocking; if the worker already acquired
        # it on its way out, this call simply fails and the lock stays held.
        self.puppet_arm_publish_lock.acquire(False)
        self.puppet_arm_publish_thread = None
    self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
    self.puppet_arm_publish_thread.start()
|
|
||||||
|
|
||||||
def get_frame(self):
    """Pop one time-aligned set of sensor messages from the buffers.

    The frame time is the smallest of the newest timestamps across the colour
    image buffers (and the depth image buffers when ``args.use_depth_image``
    is set). Every required buffer must hold a message at least that recent;
    otherwise nothing is consumed and ``False`` is returned. On success, each
    buffer drops its messages older than the frame time and then yields the
    next one.

    Returns:
        ``False`` when a synchronized frame is not available yet, otherwise
        the tuple ``(img_front, img_left, img_right, img_front_depth,
        img_left_depth, img_right_depth, puppet_arm_left, puppet_arm_right,
        robot_base)``. Images are converted with ``bridge.imgmsg_to_cv2``;
        depth images are ``None`` without ``use_depth_image`` and
        ``robot_base`` is ``None`` without ``use_robot_base``.
    """

    def newest_stamp(buffer):
        # Timestamp (seconds) of the most recently buffered message.
        return buffer[-1].header.stamp.to_sec()

    def take_synced(buffer, threshold):
        # Drop messages older than the threshold, then pop the next one.
        while buffer[0].header.stamp.to_sec() < threshold:
            buffer.popleft()
        return buffer.popleft()

    with_depth = self.args.use_depth_image

    image_buffers = [self.img_left_deque, self.img_right_deque, self.img_front_deque]
    if with_depth:
        image_buffers += [self.img_left_depth_deque, self.img_right_depth_deque, self.img_front_depth_deque]

    if any(len(buffer) == 0 for buffer in image_buffers):
        return False

    frame_time = min(newest_stamp(buffer) for buffer in image_buffers)

    required_buffers = image_buffers + [self.puppet_arm_left_deque, self.puppet_arm_right_deque]
    with_base = self.args.use_robot_base
    if with_base:
        required_buffers.append(self.robot_base_deque)

    # Every stream must already have caught up with the frame time.
    for buffer in required_buffers:
        if len(buffer) == 0 or newest_stamp(buffer) < frame_time:
            return False

    to_cv2 = self.bridge.imgmsg_to_cv2
    img_left = to_cv2(take_synced(self.img_left_deque, frame_time), 'passthrough')
    img_right = to_cv2(take_synced(self.img_right_deque, frame_time), 'passthrough')
    img_front = to_cv2(take_synced(self.img_front_deque, frame_time), 'passthrough')

    puppet_arm_left = take_synced(self.puppet_arm_left_deque, frame_time)
    puppet_arm_right = take_synced(self.puppet_arm_right_deque, frame_time)

    img_left_depth = None
    img_right_depth = None
    img_front_depth = None
    if with_depth:
        img_left_depth = to_cv2(take_synced(self.img_left_depth_deque, frame_time), 'passthrough')
        img_right_depth = to_cv2(take_synced(self.img_right_depth_deque, frame_time), 'passthrough')
        img_front_depth = to_cv2(take_synced(self.img_front_depth_deque, frame_time), 'passthrough')

    robot_base = take_synced(self.robot_base_deque, frame_time) if with_base else None

    return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
            puppet_arm_left, puppet_arm_right, robot_base)
|
|
||||||
|
|
||||||
def img_left_callback(self, msg):
    """Buffer a message from the left image topic, evicting the oldest once 2000 are held."""
    buffer = self.img_left_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def img_right_callback(self, msg):
    """Buffer a message from the right image topic, evicting the oldest once 2000 are held."""
    buffer = self.img_right_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def img_front_callback(self, msg):
    """Buffer a message from the front image topic, evicting the oldest once 2000 are held."""
    buffer = self.img_front_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def img_left_depth_callback(self, msg):
    """Buffer a message from the left depth topic, evicting the oldest once 2000 are held."""
    buffer = self.img_left_depth_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def img_right_depth_callback(self, msg):
    """Buffer a message from the right depth topic, evicting the oldest once 2000 are held."""
    buffer = self.img_right_depth_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def img_front_depth_callback(self, msg):
    """Buffer a message from the front depth topic, evicting the oldest once 2000 are held."""
    buffer = self.img_front_depth_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def puppet_arm_left_callback(self, msg):
    """Buffer a left puppet-arm joint state, evicting the oldest once 2000 are held."""
    buffer = self.puppet_arm_left_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def puppet_arm_right_callback(self, msg):
    """Buffer a right puppet-arm joint state, evicting the oldest once 2000 are held."""
    buffer = self.puppet_arm_right_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def robot_base_callback(self, msg):
    """Buffer a robot-base message, evicting the oldest once 2000 are held."""
    buffer = self.robot_base_deque
    if len(buffer) >= 2000:
        buffer.popleft()  # make room by discarding the oldest entry
    buffer.append(msg)
|
|
||||||
|
|
||||||
def ctrl_callback(self, msg):
    """Store the payload of a control message as the current control state.

    The update happens under ``ctrl_state_lock``. The lock is taken with a
    ``with`` block so it is released even if reading ``msg.data`` raises.

    Args:
        msg: Message whose ``data`` attribute holds the new control state.
    """
    with self.ctrl_state_lock:
        self.ctrl_state = msg.data
|
|
||||||
|
|
||||||
def get_ctrl_state(self):
    """Return the current control state, read under ``ctrl_state_lock``.

    The lock is taken with a ``with`` block so it is released even if the
    read raises (for example when no state has been stored yet).
    """
    with self.ctrl_state_lock:
        return self.ctrl_state
|
|
||||||
|
|
||||||
def init_ros(self):
    """Initialise the ROS node and create all subscribers and publishers.

    Subscribes to the three colour image topics, the three depth image topics
    (only when ``args.use_depth_image`` is set), both puppet-arm joint-state
    topics and the robot-base topic. Each subscription is routed to the
    matching ``*_callback`` method. Then creates the command publishers for
    the left arm, the right arm and the robot base.
    """
    rospy.init_node('joint_state_publisher', anonymous=True)
    args = self.args

    def subscribe(topic, msg_type, handler):
        # All subscriptions share the same queue size and TCP settings.
        rospy.Subscriber(topic, msg_type, handler, queue_size=1000, tcp_nodelay=True)

    subscribe(args.img_left_topic, Image, self.img_left_callback)
    subscribe(args.img_right_topic, Image, self.img_right_callback)
    subscribe(args.img_front_topic, Image, self.img_front_callback)
    if args.use_depth_image:
        subscribe(args.img_left_depth_topic, Image, self.img_left_depth_callback)
        subscribe(args.img_right_depth_topic, Image, self.img_right_depth_callback)
        subscribe(args.img_front_depth_topic, Image, self.img_front_depth_callback)
    subscribe(args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback)
    subscribe(args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback)
    subscribe(args.robot_base_topic, Odometry, self.robot_base_callback)

    self.puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
    self.puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
    self.robot_base_publisher = rospy.Publisher(args.robot_base_cmd_topic, Twist, queue_size=10)
|
|
||||||
|
|
||||||
|
|
||||||
def get_arguments():
    """Parse the command-line arguments of the inference script.

    Boolean options (``--temporal_agg``, ``--use_robot_base``,
    ``--use_actions_interpolation``, ``--use_depth_image``) accept textual
    values such as ``true``/``false``/``1``/``0``. ``--ema_power`` accepts a
    fractional value, matching its default of 0.75.

    Returns:
        argparse.Namespace: The parsed arguments from ``sys.argv``.
    """

    def str2bool(value):
        # argparse's ``type=bool`` calls bool() on the raw string, so any
        # non-empty text -- including "False" -- used to parse as True.
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ('true', 't', 'yes', 'y', '1'):
            return True
        if text in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()

    # Checkpoint / training-style options.
    parser.add_argument('--ckpt_dir', type=str, help='ckpt_dir', required=True)
    parser.add_argument('--task_name', type=str, help='task_name', default='aloha_mobile_dummy')
    parser.add_argument('--max_publish_step', type=int, help='max_publish_step', default=10000)
    parser.add_argument('--ckpt_name', type=str, help='ckpt_name', default='policy_best.ckpt')
    parser.add_argument('--ckpt_stats_name', type=str, help='ckpt_stats_name', default='dataset_stats.pkl')
    parser.add_argument('--policy_class', type=str, help='policy_class, capitalize', default='ACT')
    parser.add_argument('--batch_size', type=int, help='batch_size', default=8)
    parser.add_argument('--seed', type=int, help='seed', default=0)
    parser.add_argument('--num_epochs', type=int, help='num_epochs', default=2000)
    parser.add_argument('--lr', type=float, help='lr', default=1e-5)
    parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4)
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)")
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features")
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")
    parser.add_argument('--kl_weight', type=int, help='KL Weight', default=10)
    parser.add_argument('--hidden_dim', type=int, help='hidden_dim', default=512)
    parser.add_argument('--dim_feedforward', type=int, help='dim_feedforward', default=3200)
    parser.add_argument('--temporal_agg', type=str2bool, help='temporal_agg', default=True)

    parser.add_argument('--state_dim', type=int, help='state_dim', default=14)
    parser.add_argument('--lr_backbone', type=float, help='lr_backbone', default=1e-5)
    parser.add_argument('--backbone', type=str, help='backbone', default='resnet18')
    parser.add_argument('--loss_function', type=str, help='loss_function l1 l2 l1+l2', default='l1')
    parser.add_argument('--enc_layers', type=int, help='enc_layers', default=4)
    parser.add_argument('--dec_layers', type=int, help='dec_layers', default=7)
    parser.add_argument('--nheads', type=int, help='nheads', default=8)
    parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer")
    parser.add_argument('--pre_norm', action='store_true')

    # Colour image topics.
    parser.add_argument('--img_front_topic', type=str, help='img_front_topic',
                        default='/camera_f/color/image_raw')
    parser.add_argument('--img_left_topic', type=str, help='img_left_topic',
                        default='/camera_l/color/image_raw')
    parser.add_argument('--img_right_topic', type=str, help='img_right_topic',
                        default='/camera_r/color/image_raw')

    # Depth image topics.
    parser.add_argument('--img_front_depth_topic', type=str, help='img_front_depth_topic',
                        default='/camera_f/depth/image_raw')
    parser.add_argument('--img_left_depth_topic', type=str, help='img_left_depth_topic',
                        default='/camera_l/depth/image_raw')
    parser.add_argument('--img_right_depth_topic', type=str, help='img_right_depth_topic',
                        default='/camera_r/depth/image_raw')

    # Arm command and state topics.
    parser.add_argument('--puppet_arm_left_cmd_topic', type=str, help='puppet_arm_left_cmd_topic',
                        default='/master/joint_left')
    parser.add_argument('--puppet_arm_right_cmd_topic', type=str, help='puppet_arm_right_cmd_topic',
                        default='/master/joint_right')
    parser.add_argument('--puppet_arm_left_topic', type=str, help='puppet_arm_left_topic',
                        default='/puppet/joint_left')
    parser.add_argument('--puppet_arm_right_topic', type=str, help='puppet_arm_right_topic',
                        default='/puppet/joint_right')

    # Mobile base topics and runtime behaviour.
    parser.add_argument('--robot_base_topic', type=str, help='robot_base_topic',
                        default='/odom_raw')
    parser.add_argument('--robot_base_cmd_topic', type=str, help='robot_base_cmd_topic',
                        default='/cmd_vel')
    parser.add_argument('--use_robot_base', type=str2bool, help='use_robot_base',
                        default=False)
    parser.add_argument('--publish_rate', type=int, help='publish_rate',
                        default=40)
    parser.add_argument('--pos_lookahead_step', type=int, help='pos_lookahead_step',
                        default=0)
    parser.add_argument('--chunk_size', type=int, help='chunk_size',
                        default=32)
    # NOTE(review): the default is a 7-element list but a command-line value
    # parses to a single float; left unchanged because the consumer is not
    # visible here -- confirm whether nargs='+' is wanted.
    parser.add_argument('--arm_steps_length', type=float, help='arm_steps_length',
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2])

    parser.add_argument('--use_actions_interpolation', type=str2bool, help='use_actions_interpolation',
                        default=False)
    parser.add_argument('--use_depth_image', type=str2bool, help='use_depth_image',
                        default=False)

    # for Diffusion
    parser.add_argument('--observation_horizon', type=int, help='observation_horizon', default=1)
    parser.add_argument('--action_horizon', type=int, help='action_horizon', default=8)
    parser.add_argument('--num_inference_timesteps', type=int, help='num_inference_timesteps', default=10)
    # The default is fractional, so the option must parse as float, not int.
    parser.add_argument('--ema_power', type=float, help='ema_power', default=0.75)

    return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def main():
    """Parse the arguments, set up the ROS operator and run model inference."""
    cli_args = get_arguments()
    operator = RosOperator(cli_args)
    model_config = get_model_config(cli_args)
    model_inference(cli_args, model_config, operator, save_episode=True)
|
|
||||||
|
|
||||||
|
|
||||||
# Example invocation:
#   python act/inference.py --ckpt_dir ~/train0314/
if __name__ == "__main__":
    main()
|
|
||||||
207
collect_data/replay_data.py
Normal file → Executable file
207
collect_data/replay_data.py
Normal file → Executable file
@@ -10,24 +10,63 @@ from cv_bridge import CvBridge
|
|||||||
from std_msgs.msg import Header
|
from std_msgs.msg import Header
|
||||||
from sensor_msgs.msg import Image, JointState
|
from sensor_msgs.msg import Image, JointState
|
||||||
from geometry_msgs.msg import Twist
|
from geometry_msgs.msg import Twist
|
||||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
sys.path.append("./")
|
||||||
|
def load_hdf5(dataset_dir, dataset_name):
    """Load one recorded episode from ``<dataset_dir>/<dataset_name>.hdf5``.

    Prints a message and exits the process (``SystemExit``) when the file
    does not exist.

    Args:
        dataset_dir: Directory containing the episode file.
        dataset_name: File name without the ``.hdf5`` extension.

    Returns:
        Tuple ``(qpos, qvel, effort, action, base_action, image_dict)``.
        ``effort`` is ``None`` when the file has no ``/observations/effort``
        dataset. ``image_dict`` maps each camera name under
        ``/observations/images`` to its frames; when the file's ``compress``
        attribute is set, each frame is decoded with ``cv2.imdecode``.
    """
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        # Raise SystemExit directly rather than relying on the site-provided
        # exit() helper, and report failure through a non-zero status.
        raise SystemExit(1)

    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        # Test for the dataset that is actually read; the previous check
        # looked for 'effort' among the top-level names only.
        if '/observations/effort' in root:
            effort = root['/observations/effort'][()]
        else:
            effort = None
        action = root['/action'][()]
        base_action = root['/base_action'][()]

        image_dict = {
            cam_name: root[f'/observations/images/{cam_name}'][()]
            for cam_name in root['/observations/images/'].keys()
        }

    if compressed:
        # Each stored frame is handed to the decoder as-is, exactly as before;
        # the per-frame lengths in /compress_len were read but never applied.
        image_dict = {
            cam_name: [cv2.imdecode(frame, 1) for frame in frames]
            for cam_name, frames in image_dict.items()
        }

    return qpos, qvel, effort, action, base_action, image_dict
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
rospy.init_node("replay_node")
|
rospy.init_node("replay_node")
|
||||||
bridge = CvBridge()
|
bridge = CvBridge()
|
||||||
# img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
|
img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
|
||||||
# img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
|
img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
|
||||||
# img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
|
img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
|
||||||
|
|
||||||
# puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
|
puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
|
||||||
# puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
|
puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
|
||||||
|
|
||||||
master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
|
master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
|
||||||
master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)
|
master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)
|
||||||
|
|
||||||
# robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
|
robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
|
||||||
|
|
||||||
|
|
||||||
# dataset_dir = args.dataset_dir
|
# dataset_dir = args.dataset_dir
|
||||||
@@ -35,78 +74,130 @@ def main(args):
|
|||||||
# task_name = args.task_name
|
# task_name = args.task_name
|
||||||
# dataset_name = f'episode_{episode_idx}'
|
# dataset_name = f'episode_{episode_idx}'
|
||||||
|
|
||||||
dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
|
|
||||||
actions = dataset.hf_dataset.select_columns("action")
|
|
||||||
velocitys = dataset.hf_dataset.select_columns("observation.velocity")
|
|
||||||
efforts = dataset.hf_dataset.select_columns("observation.effort")
|
|
||||||
|
|
||||||
origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279]
|
origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279]
|
||||||
origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
|
origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
|
||||||
|
|
||||||
|
|
||||||
joint_state_msg = JointState()
|
joint_state_msg = JointState()
|
||||||
joint_state_msg.header = Header()
|
joint_state_msg.header = Header()
|
||||||
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', ''] # 设置关节名称
|
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称
|
||||||
twist_msg = Twist()
|
twist_msg = Twist()
|
||||||
|
|
||||||
rate = rospy.Rate(args.fps)
|
rate = rospy.Rate(args.frame_rate)
|
||||||
|
|
||||||
# qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)
|
# qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)
|
||||||
|
|
||||||
|
actions = [[ 0.1277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.3238, -0.1094,
|
||||||
|
0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]]
|
||||||
|
|
||||||
last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
|
if not args.only_pub_master:
|
||||||
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
|
last_action = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279, 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
|
||||||
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
|
rate = rospy.Rate(100)
|
||||||
rate = rospy.Rate(50)
|
for action in actions:
|
||||||
for idx in range(len(actions)):
|
if(rospy.is_shutdown()):
|
||||||
action = actions[idx]['action'].detach().cpu().numpy()
|
break
|
||||||
velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy()
|
|
||||||
effort = efforts[idx]['observation.effort'].detach().cpu().numpy()
|
|
||||||
if(rospy.is_shutdown()):
|
|
||||||
break
|
|
||||||
|
|
||||||
new_actions = np.linspace(last_action, action, 5) # 插值
|
|
||||||
new_velocitys = np.linspace(last_velocity, velocity, 5) # 插值
|
|
||||||
new_efforts = np.linspace(last_effort, effort, 5) # 插值
|
|
||||||
last_action = action
|
|
||||||
last_velocity = velocity
|
|
||||||
last_effort = effort
|
|
||||||
for act in new_actions:
|
|
||||||
print(np.round(act[:7], 4))
|
|
||||||
cur_timestamp = rospy.Time.now() # 设置时间戳
|
|
||||||
joint_state_msg.header.stamp = cur_timestamp
|
|
||||||
|
|
||||||
joint_state_msg.position = act[:7]
|
new_actions = np.linspace(last_action, action, 50) # 插值
|
||||||
joint_state_msg.velocity = last_velocity[:7]
|
last_action = action
|
||||||
joint_state_msg.effort = last_effort[:7]
|
for act in new_actions:
|
||||||
|
print(np.round(act[:7], 4))
|
||||||
|
cur_timestamp = rospy.Time.now() # 设置时间戳
|
||||||
|
joint_state_msg.header.stamp = cur_timestamp
|
||||||
|
|
||||||
|
joint_state_msg.position = act[:7]
|
||||||
|
master_arm_left_publisher.publish(joint_state_msg)
|
||||||
|
|
||||||
|
joint_state_msg.position = act[7:]
|
||||||
|
master_arm_right_publisher.publish(joint_state_msg)
|
||||||
|
|
||||||
|
if(rospy.is_shutdown()):
|
||||||
|
break
|
||||||
|
rate.sleep()
|
||||||
|
|
||||||
|
else:
|
||||||
|
i = 0
|
||||||
|
while(not rospy.is_shutdown() and i < len(actions)):
|
||||||
|
print("left: ", np.round(qposs[i][:7], 4), " right: ", np.round(qposs[i][7:], 4))
|
||||||
|
|
||||||
|
cam_names = [k for k in image_dicts.keys()]
|
||||||
|
image0 = image_dicts[cam_names[0]][i]
|
||||||
|
image0 = image0[:, :, [2, 1, 0]] # swap B and R channel
|
||||||
|
|
||||||
|
image1 = image_dicts[cam_names[1]][i]
|
||||||
|
image1 = image1[:, :, [2, 1, 0]] # swap B and R channel
|
||||||
|
|
||||||
|
image2 = image_dicts[cam_names[2]][i]
|
||||||
|
image2 = image2[:, :, [2, 1, 0]] # swap B and R channel
|
||||||
|
|
||||||
|
cur_timestamp = rospy.Time.now() # 设置时间戳
|
||||||
|
|
||||||
|
joint_state_msg.header.stamp = cur_timestamp
|
||||||
|
joint_state_msg.position = actions[i][:7]
|
||||||
master_arm_left_publisher.publish(joint_state_msg)
|
master_arm_left_publisher.publish(joint_state_msg)
|
||||||
|
|
||||||
joint_state_msg.position = act[7:]
|
joint_state_msg.position = actions[i][7:]
|
||||||
joint_state_msg.velocity = last_velocity[:7]
|
|
||||||
joint_state_msg.effort = last_effort[7:]
|
|
||||||
master_arm_right_publisher.publish(joint_state_msg)
|
master_arm_right_publisher.publish(joint_state_msg)
|
||||||
|
|
||||||
if(rospy.is_shutdown()):
|
joint_state_msg.position = qposs[i][:7]
|
||||||
break
|
puppet_arm_left_publisher.publish(joint_state_msg)
|
||||||
rate.sleep()
|
|
||||||
|
joint_state_msg.position = qposs[i][7:]
|
||||||
|
puppet_arm_right_publisher.publish(joint_state_msg)
|
||||||
|
|
||||||
|
img_front_publisher.publish(bridge.cv2_to_imgmsg(image0, "bgr8"))
|
||||||
|
img_left_publisher.publish(bridge.cv2_to_imgmsg(image1, "bgr8"))
|
||||||
|
img_right_publisher.publish(bridge.cv2_to_imgmsg(image2, "bgr8"))
|
||||||
|
|
||||||
|
|
||||||
|
twist_msg.linear.x = base_actions[i][0]
|
||||||
|
twist_msg.angular.z = base_actions[i][1]
|
||||||
|
robot_base_publisher.publish(twist_msg)
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
rate.sleep()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
# parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
|
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=False)
|
||||||
# default='/master/joint_left', required=False)
|
parser.add_argument('--task_name', action='store', type=str, help='Task name.',
|
||||||
# parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
|
default="aloha_mobile_dummy", required=False)
|
||||||
# default='/master/joint_right', required=False)
|
|
||||||
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
args.repo_id = "tangger/test"
|
|
||||||
args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
|
|
||||||
args.episode = 1 # replay episode
|
|
||||||
args.master_arm_left_topic = "/master/joint_left"
|
|
||||||
args.master_arm_right_topic = "/master/joint_right"
|
|
||||||
args.fps = 30
|
|
||||||
|
|
||||||
|
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.',default=0, required=False)
|
||||||
|
|
||||||
|
parser.add_argument('--camera_names', action='store', type=str, help='camera_names',
|
||||||
|
default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False)
|
||||||
|
|
||||||
|
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
|
||||||
|
default='/camera_f/color/image_raw', required=False)
|
||||||
|
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
|
||||||
|
default='/camera_l/color/image_raw', required=False)
|
||||||
|
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
|
||||||
|
default='/camera_r/color/image_raw', required=False)
|
||||||
|
|
||||||
|
parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
|
||||||
|
default='/master/joint_left', required=False)
|
||||||
|
parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
|
||||||
|
default='/master/joint_right', required=False)
|
||||||
|
|
||||||
|
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
|
||||||
|
default='/puppet/joint_left', required=False)
|
||||||
|
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
|
||||||
|
default='/puppet/joint_right', required=False)
|
||||||
|
|
||||||
|
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
|
||||||
|
default='/cmd_vel', required=False)
|
||||||
|
parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
|
||||||
|
default=False, required=False)
|
||||||
|
|
||||||
|
parser.add_argument('--frame_rate', action='store', type=int, help='frame_rate',
|
||||||
|
default=30, required=False)
|
||||||
|
|
||||||
|
parser.add_argument('--only_pub_master', action='store_true', help='only_pub_master',required=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
# python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0
|
# python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0
|
||||||
10
collect_data/requirements.txt
Executable file
10
collect_data/requirements.txt
Executable file
@@ -0,0 +1,10 @@
|
|||||||
|
opencv-python==4.9.0.80
|
||||||
|
matplotlib==3.7.5
|
||||||
|
h5py==3.8.0
|
||||||
|
dm-env==1.6
|
||||||
|
numpy==1.23.4
|
||||||
|
pyyaml
|
||||||
|
rospkg==1.5.0
|
||||||
|
catkin-pkg==1.0.0
|
||||||
|
empy==3.3.4
|
||||||
|
PyQt5==5.15.10
|
||||||
@@ -1,372 +0,0 @@
|
|||||||
import yaml
|
|
||||||
from collections import deque
|
|
||||||
import rospy
|
|
||||||
from cv_bridge import CvBridge
|
|
||||||
from typing import Dict, Any, Optional, List
|
|
||||||
from sensor_msgs.msg import Image, JointState
|
|
||||||
from nav_msgs.msg import Odometry
|
|
||||||
import argparse
|
|
||||||
|
|
||||||
|
|
||||||
class Robot:
    """Base class for ROS-backed robots; runs the shared start-up sequence.

    Construction loads the YAML config, overlays runtime arguments, creates
    the per-topic message queues, registers the ROS subscribers/publishers,
    builds the ``features`` description and finally waits for the sensor
    queues to fill (see :meth:`warmup`).

    Subclasses must implement :meth:`get_frame` and :meth:`process`.
    """

    # Upper bound on buffered messages per topic. The deques are created with
    # this ``maxlen`` so the oldest message is dropped automatically.
    _QUEUE_MAXLEN = 2000
    # Minimum number of buffered messages warmup() waits for.
    _WARMUP_MIN_IMAGES = 50
    _WARMUP_MIN_BASE = 20

    def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
        """Initialise the robot.

        Args:
            config_file: Path to the YAML configuration file.
            args: Optional runtime arguments that override config values.
        """
        self._load_config(config_file)
        self._merge_runtime_args(args)
        self._init_components()
        self._init_data_structures()
        self.init_ros()
        self.init_features()
        self.warmup()

    def _load_config(self, config_file: str) -> None:
        """Load the YAML configuration file into ``self.config``."""
        with open(config_file, 'r') as f:
            # safe_load returns None for an empty file; fall back to an empty
            # mapping so _validate_config can report the missing sections.
            self.config = yaml.safe_load(f) or {}

    def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
        """Overlay runtime arguments onto the loaded configuration.

        Only attributes that are present on ``args`` and not ``None`` replace
        the corresponding config entry; everything else keeps its YAML value.
        """
        if args is None:
            return

        # config key -> attribute name on ``args``
        overrides = {
            'frame_rate': 'fps',
            'max_timesteps': 'max_timesteps',
            'episode_idx': 'episode_idx',
            'use_depth_image': 'use_depth_image',
            'use_robot_base': 'use_base',
            'video': 'video',
            'control_type': 'control_type',
        }

        for key, attr in overrides.items():
            # A missing attribute must not clobber the YAML value, so the
            # fallback is None for every key (including video/control_type).
            value = getattr(args, attr, None)
            if value is not None:
                self.config[key] = value

    def _init_components(self) -> None:
        """Create the core helper objects and validate the configuration."""
        self.bridge = CvBridge()
        self.subscribers = {}
        self.publishers = {}
        self._validate_config()

    def _validate_config(self) -> None:
        """Check that the mandatory config sections exist.

        Raises:
            ValueError: If ``cameras`` or ``arm`` is missing from the config.
        """
        required_sections = ['cameras', 'arm']
        for section in required_sections:
            if section not in self.config:
                raise ValueError(f"Missing required config section: {section}")

    def _new_queue(self) -> deque:
        """Return an empty bounded message queue."""
        return deque(maxlen=self._QUEUE_MAXLEN)

    def _init_data_structures(self) -> None:
        """Create the message queues for cameras, arms and the base."""
        # Camera (RGB) queues
        self.cameras = self.config.get('cameras', {})
        self.sync_img_queues = {name: self._new_queue() for name in self.cameras}

        # Depth queues, only for cameras that declare a depth topic
        self.use_depth_image = self.config.get('use_depth_image', False)
        if self.use_depth_image:
            self.sync_depth_queues = {
                name: self._new_queue()
                for name, cam in self.cameras.items()
                if 'depth_topic_name' in cam
            }

        # Arm queues: outside record mode only the puppet arms are buffered
        self.arms = self.config.get('arm', {})
        if self.config.get('control_type', '') != 'record':
            self.sync_arm_queues = {
                name: self._new_queue() for name in self.arms if 'puppet' in name
            }
        else:
            self.sync_arm_queues = {name: self._new_queue() for name in self.arms}

        # Mobile base queue
        self.use_robot_base = self.config.get('use_robot_base', False)
        if self.use_robot_base:
            self.sync_base_queue = self._new_queue()

    def init_ros(self) -> None:
        """Initialise the ROS node and register all subscribers/publishers."""
        rospy.init_node(
            f"{self.config.get('ros_node_name', 'generic_robot_node')}",
            anonymous=True
        )

        self._setup_camera_subscribers()
        self._setup_arm_subscribers_publishers()
        self._setup_base_subscriber()
        self._log_ros_status()

    def init_features(self):
        """Build ``self.features`` from the YAML configuration and print it."""
        self.features = {}

        self._init_camera_features()
        self._init_state_features()
        self._init_action_features()

        # Base features are only present when the mobile base is enabled
        if self.use_robot_base:
            self._init_base_features()

        import pprint
        pprint.pprint(self.features, indent=4)

    def _init_camera_features(self):
        """Add one image feature (and optionally one depth feature) per camera."""
        for cam_name, cam_config in self.cameras.items():
            # Colour image
            self.features[f"observation.images.{cam_name}"] = {
                "dtype": "video" if self.config.get("video", False) else "image",
                "shape": cam_config.get("rgb_shape", [480, 640, 3]),
                "names": ["height", "width", "channel"],
            }

            if self.config.get("use_depth_image", False):
                self.features[f"observation.images.depth_{cam_name}"] = {
                    "dtype": "uint16",
                    # Height first, then width, to agree with "names" below
                    # (defaults: 480 rows x 640 columns).
                    "shape": (cam_config.get("height", 480), cam_config.get("width", 640), 1),
                    # NOTE(review): the shape has a trailing channel axis that
                    # is not named here - confirm whether consumers need it.
                    "names": ["height", "width"],
                }

    @staticmethod
    def _motor_feature(section) -> dict:
        """Build a float32 vector feature from a config section with ``motors``."""
        motors = section.get('motors', "")
        return {
            "dtype": "float32",
            "shape": (len(motors),),
            "names": {"motors": motors}
        }

    def _init_state_features(self):
        """Add the state feature plus optional velocity/effort features."""
        self.features["observation.state"] = self._motor_feature(self.config.get('state', {}))

        # Velocity and effort are only described when the config defines them
        if self.config.get('velocity'):
            self.features["observation.velocity"] = self._motor_feature(self.config.get('velocity', ""))

        if self.config.get('effort'):
            self.features["observation.effort"] = self._motor_feature(self.config.get('effort', ""))

    def _init_action_features(self):
        """Add the action feature."""
        self.features["action"] = self._motor_feature(self.config.get('action', {}))

    def _init_base_features(self):
        """Add the mobile-base velocity feature."""
        self.features["observation.base_vel"] = {
            "dtype": "float32",
            "shape": (2,),
            "names": ["linear_x", "angular_z"]
        }

    def _setup_camera_subscribers(self) -> None:
        """Subscribe to the colour (and optionally depth) topic of each camera."""
        for cam_name, cam_config in self.cameras.items():
            if 'img_topic_name' in cam_config:
                self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber(
                    cam_config['img_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=False),
                    queue_size=1000,
                    tcp_nodelay=True
                )

            if self.use_depth_image and 'depth_topic_name' in cam_config:
                self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber(
                    cam_config['depth_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=True),
                    queue_size=1000,
                    tcp_nodelay=True
                )

    def _subscribe_arm(self, arm_name: str, topic: str) -> None:
        """Register a JointState subscriber that feeds the queue of ``arm_name``."""
        self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
            topic,
            JointState,
            self._make_arm_callback(arm_name),
            queue_size=1000,
            tcp_nodelay=True
        )

    def _setup_arm_subscribers_publishers(self) -> None:
        """Register arm subscribers and publishers.

        In ``record`` mode every arm that declares a topic is subscribed to.
        Otherwise only the puppet arms are subscribed to, while commands are
        published on the master arm topics.
        """
        record_mode = self.config.get('control_type', '') == 'record'
        for arm_name, arm_config in self.arms.items():
            if record_mode:
                if 'topic_name' in arm_config:
                    self._subscribe_arm(arm_name, arm_config['topic_name'])
                continue

            if 'puppet' in arm_name:
                self._subscribe_arm(arm_name, arm_config['topic_name'])
            if 'master' in arm_name:
                self.publishers[f"arm_{arm_name}"] = rospy.Publisher(
                    arm_config['topic_name'],
                    JointState,
                    queue_size=10
                )

    def _setup_base_subscriber(self) -> None:
        """Subscribe to the base odometry topic when the base is enabled."""
        if self.use_robot_base and 'robot_base' in self.config:
            self.subscribers['base'] = rospy.Subscriber(
                self.config['robot_base']['topic_name'],
                Odometry,
                self.robot_base_callback,
                queue_size=1000,
                tcp_nodelay=True
            )

    def _log_ros_status(self) -> None:
        """Log the node name and the state of every registered subscriber."""
        rospy.loginfo("\n=== ROS订阅状态 ===")
        rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
        rospy.loginfo("活跃的订阅者:")
        for topic, sub in self.subscribers.items():
            rospy.loginfo(f" - {topic}: {'活跃' if sub.impl else '未连接'}")
        rospy.loginfo("=================")

    def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
        """Return a callback that buffers messages of one camera stream."""
        def callback(msg):
            try:
                queues = self.sync_depth_queues if is_depth else self.sync_img_queues
                # The deque is bounded, so the oldest message is dropped
                # automatically once the queue is full.
                queues[cam_name].append(msg)
            except Exception as e:
                # Best effort: a failing callback must not kill the subscriber.
                rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
        return callback

    def _make_arm_callback(self, arm_name: str):
        """Return a callback that buffers joint-state messages of one arm."""
        def callback(msg):
            try:
                self.sync_arm_queues[arm_name].append(msg)
            except Exception as e:
                # Best effort: a failing callback must not kill the subscriber.
                rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
        return callback

    def robot_base_callback(self, msg):
        """Default base callback: buffer the odometry message."""
        self.sync_base_queue.append(msg)

    def warmup(self, timeout: float = 10.0) -> bool:
        """Wait until the sensor queues hold enough messages.

        Every camera image queue (and every depth queue, when depth is
        enabled) must hold at least 50 messages, and the base queue at least
        20 when the base is enabled. Arm queues are not checked.

        Args:
            timeout: Maximum time to wait in seconds before giving up.

        Returns:
            bool: True if warmup succeeded, False if it timed out or ROS was
            shut down first.
        """
        start_time = rospy.Time.now().to_sec()
        rate = rospy.Rate(10)  # Check at 10Hz

        rospy.loginfo("Starting warmup process...")

        while not rospy.is_shutdown():
            current_time = rospy.Time.now().to_sec()
            if current_time - start_time > timeout:
                rospy.logwarn("Warmup timed out before all queues were filled")
                return False

            all_ready = True

            # Camera image queues: report only the first one that lags behind
            for cam_name in self.cameras:
                filled = len(self.sync_img_queues[cam_name])
                if filled < self._WARMUP_MIN_IMAGES:
                    rospy.loginfo(
                        f"Waiting for camera {cam_name} (current: {filled}/{self._WARMUP_MIN_IMAGES})")
                    all_ready = False
                    break

            # Depth queues, if enabled
            if self.use_depth_image:
                for cam_name in self.sync_depth_queues:
                    filled = len(self.sync_depth_queues[cam_name])
                    if filled < self._WARMUP_MIN_IMAGES:
                        rospy.loginfo(
                            f"Waiting for depth camera {cam_name} (current: {filled}/{self._WARMUP_MIN_IMAGES})")
                        all_ready = False
                        break

            # Base queue, if enabled
            if self.use_robot_base:
                filled = len(self.sync_base_queue)
                if filled < self._WARMUP_MIN_BASE:
                    rospy.loginfo(f"Waiting for base (current: {filled}/{self._WARMUP_MIN_BASE})")
                    all_ready = False

            if all_ready:
                rospy.loginfo("Warmup completed successfully")
                return True

            rate.sleep()

        return False

    def get_frame(self) -> Optional[Dict[str, Any]]:
        """Return one synchronised frame of data; implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement get_frame()")

    def process(self) -> tuple:
        """Run the main processing loop; implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement process()")
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
import yaml
|
|
||||||
import argparse
|
|
||||||
from typing import Dict, List, Any, Optional
|
|
||||||
from rosrobot import Robot
|
|
||||||
from agilex_robot import AgilexRobot
|
|
||||||
|
|
||||||
|
|
||||||
class RobotFactory:
    """Factory that builds the robot implementation named in a YAML config."""

    @staticmethod
    def create(config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
        """Create the robot instance selected by the configuration file.

        Args:
            config_file: Path to the YAML configuration file.
            args: Optional runtime arguments forwarded to the robot.

        Returns:
            The robot instance matching the ``robot_type`` config entry
            (``agilex`` when the entry is absent).

        Raises:
            ValueError: If ``robot_type`` names an unsupported robot.
        """
        with open(config_file, 'r') as f:
            # safe_load returns None for an empty file; treat that as an
            # empty mapping so the default robot type still applies.
            config = yaml.safe_load(f) or {}

        robot_type = config.get('robot_type', 'agilex')

        if robot_type == 'agilex':
            return AgilexRobot(config_file, args)
        # Further robot types can be added here.
        raise ValueError(f"Unsupported robot type: {robot_type}")
|
|
||||||
322
collect_data/utils.py
Normal file
322
collect_data/utils.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import h5py
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def display_camera_grid(image_dict, grid_shape=None, window_name="MindRobot-V1 Data Collection", scale=1.0):
    """Show several camera images side by side in one OpenCV window.

    Each image keeps its own aspect ratio; the whole mosaic is scaled by
    ``scale``. Entries that are not numpy arrays, or whose shape cannot be
    converted to a 3-channel image, are skipped.

    Args:
        image_dict: Mapping of camera name to image array. 2-D and (H, W, 1)
            arrays are treated as grayscale, 4-channel arrays are converted
            with BGRA->BGR and 3-channel arrays with RGB->BGR.
        grid_shape: (rows, cols) layout. When None, up to three images are
            placed in one row, otherwise two rows are used.
        window_name: Title of the OpenCV window.
        scale: Overall display scale (0.5 shows the mosaic at 50% size).

    Returns:
        The rendered canvas as a uint8 BGR array, or None when there is no
        displayable image.

    Raises:
        TypeError: If ``image_dict`` is not a dict.
        ValueError: If ``scale`` is not positive.
    """
    if not isinstance(image_dict, dict):
        raise TypeError("输入必须是字典类型")
    if scale <= 0:
        raise ValueError("scale must be positive")

    valid_data = []
    for name, img in image_dict.items():
        if not isinstance(img, np.ndarray):
            continue
        if img.dtype != np.uint8:
            img = img.astype(np.uint8)
        # Single-channel images stored as (H, W, 1) are handled as grayscale.
        if img.ndim == 3 and img.shape[2] == 1:
            img = np.ascontiguousarray(img[:, :, 0])
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.ndim != 3:
            continue
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        elif img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        else:
            # Unsupported channel count: it could not be pasted on the canvas.
            continue
        valid_data.append((name, img))

    if not valid_data:
        print("错误: 没有有效的图像可显示!")
        return None

    # Choose the grid layout automatically when none is given
    num_valid = len(valid_data)
    if grid_shape is None:
        grid_shape = (1, num_valid) if num_valid <= 3 else (2, int(np.ceil(num_valid/2)))

    rows, cols = grid_shape
    # Images beyond the grid capacity are not shown
    shown = valid_data[:rows*cols]

    # Largest image height per row and width per column
    row_heights = [0]*rows
    col_widths = [0]*cols
    for i, (_, img) in enumerate(shown):
        r, c = i//cols, i%cols
        row_heights[r] = max(row_heights[r], img.shape[0])
        col_widths[c] = max(col_widths[c], img.shape[1])

    # Canvas size with the overall scale applied
    canvas_h = int(sum(row_heights) * scale)
    canvas_w = int(sum(col_widths) * scale)
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)

    # Top/left pixel offset of every row/column
    row_pos = [0] + [int(sum(row_heights[:i+1])*scale) for i in range(rows)]
    col_pos = [0] + [int(sum(col_widths[:i+1])*scale) for i in range(cols)]

    # Label style scales with the mosaic
    font_scale = 0.8 * scale
    thickness = max(2, int(2 * scale))

    for i, (name, img) in enumerate(shown):
        r, c = i//cols, i%cols
        x1, y1 = col_pos[c], row_pos[r]

        display_h = int(img.shape[0] * scale)
        display_w = int(img.shape[1] * scale)

        resized_img = cv2.resize(img, (display_w, display_h))
        canvas[y1:y1+display_h, x1:x1+display_w] = resized_img

        cv2.putText(canvas, name, (x1+10, y1+30),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255,255,255), thickness)

    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.imshow(window_name, canvas)
    cv2.resizeWindow(window_name, canvas_w, canvas_h)
    cv2.waitKey(1)

    return canvas
|
||||||
|
|
||||||
|
|
||||||
|
# 保存数据函数
|
||||||
|
def save_data(args, timesteps, actions, dataset_path):
    """Write one recorded episode to ``<dataset_path>.hdf5``.

    ``actions`` and ``timesteps`` are consumed from the front, one pair per
    step, so both lists are drained by this call. For each step the
    timestep's observation supplies qpos, qvel, effort, base velocity and the
    camera images, while the action is stored as sent.

    Args:
        args: Object with ``camera_names`` and ``use_depth_image``.
        timesteps: List of timesteps exposing an ``observation`` mapping.
        actions: List of actions, one per recorded step.
        dataset_path: Output path without the ``.hdf5`` suffix.
    """
    episode_len = len(actions)

    buffers = {
        '/observations/qpos': [],
        '/observations/qvel': [],
        '/observations/effort': [],
        '/action': [],
        '/base_action': [],
    }
    for cam_name in args.camera_names:
        buffers[f'/observations/images/{cam_name}'] = []
        if args.use_depth_image:
            buffers[f'/observations/images_depth/{cam_name}'] = []

    # One action is paired with the timestep at the same position.
    for _ in range(episode_len):
        action = actions.pop(0)
        ts = timesteps.pop(0)

        buffers['/observations/qpos'].append(ts.observation['qpos'])
        buffers['/observations/qvel'].append(ts.observation['qvel'])
        buffers['/observations/effort'].append(ts.observation['effort'])
        buffers['/action'].append(action)
        buffers['/base_action'].append(ts.observation['base_vel'])

        for cam_name in args.camera_names:
            buffers[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
            if args.use_depth_image:
                buffers[f'/observations/images_depth/{cam_name}'].append(ts.observation['images_depth'][cam_name])

    started = time.time()
    with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2*2) as root:
        # File-level flags: real robot data, stored uncompressed.
        root.attrs['sim'] = False
        root.attrs['compress'] = False

        obs_group = root.create_group('observations')
        rgb_group = obs_group.create_group('images')
        for cam_name in args.camera_names:
            rgb_group.create_dataset(cam_name, (episode_len, 480, 640, 3), dtype='uint8',
                                     chunks=(1, 480, 640, 3), )
        if args.use_depth_image:
            depth_group = obs_group.create_group('images_depth')
            for cam_name in args.camera_names:
                depth_group.create_dataset(cam_name, (episode_len, 480, 640), dtype='uint16',
                                           chunks=(1, 480, 640), )

        for field in ('qpos', 'qvel', 'effort'):
            obs_group.create_dataset(field, (episode_len, 14))
        root.create_dataset('action', (episode_len, 14))
        root.create_dataset('base_action', (episode_len, 2))

        # Copy every collected buffer into its pre-allocated dataset.
        for name, values in buffers.items():
            root[name][...] = values
    print(f'\033[32m\nSaving: {time.time() - started:.1f} secs. %s \033[0m\n'%dataset_path)
|
||||||
|
|
||||||
|
|
||||||
|
def is_headless():
    """Check if the environment is headless (no display available).

    The check tries to create, update and destroy a hidden Tk root window;
    any ordinary failure (tkinter missing, no display) counts as headless.

    Returns:
        bool: True if the environment is headless, False otherwise.
    """
    try:
        import tkinter as tk

        root = tk.Tk()
        try:
            root.withdraw()
            root.update()
        finally:
            # Always release the window, even if update() failed.
            root.destroy()
        return False
    except Exception:
        # Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def init_keyboard_listener():
    """Start a keyboard listener that records control requests as flags.

    Key bindings:
        - Left arrow: start data recording
        - Right arrow: save current data
        - Down arrow: discard current data
        - Up arrow: replay current data
        - ESC: early termination

    Returns:
        tuple: (listener, events). ``events`` is the flag dictionary updated
        by the listener; ``listener`` is None in a headless environment.
    """
    events = dict.fromkeys(
        ("exit_early", "record_start", "save_data", "discard_data", "replay_data"),
        False,
    )

    if is_headless():
        print(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    # (key, message to print, flag changes to apply), checked in order
    bindings = (
        (
            keyboard.Key.left,
            "← Left arrow: STARTING data recording...",
            {"record_start": True, "exit_early": False, "save_data": False, "discard_data": False},
        ),
        (
            keyboard.Key.right,
            "→ Right arrow: SAVING current data...",
            {"save_data": True, "exit_early": False, "record_start": False},
        ),
        (
            keyboard.Key.down,
            "↓ Down arrow: DISCARDING current data...",
            {"discard_data": True, "exit_early": False, "record_start": False},
        ),
        (
            keyboard.Key.up,
            "↑ Up arrow: REPLAYING current data...",
            {"replay_data": True, "exit_early": False},
        ),
        (
            keyboard.Key.esc,
            "ESC: EARLY TERMINATION requested",
            {"exit_early": True, "record_start": False},
        ),
    )

    def on_press(key):
        try:
            for trigger, message, changes in bindings:
                if key == trigger:
                    print(message)
                    events.update(changes)
                    break
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events
|
||||||
|
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from argparse import Namespace
|
||||||
|
def load_config(yaml_path):
    """Read a YAML file and expose its top-level keys as Namespace attributes."""
    with open(yaml_path, 'r') as stream:
        settings = yaml.safe_load(stream)

    # Same shape as an argparse result: one attribute per top-level key.
    return Namespace(**settings)
|
||||||
|
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
# import pyttsx3
|
||||||
|
def say(text, blocking=False):
    """Speak ``text`` with the platform's text-to-speech command.

    macOS uses ``say``, Linux uses ``edge-playback`` and Windows uses the
    System.Speech synthesizer through PowerShell.

    Args:
        text: The sentence to speak.
        blocking: When True, wait for the command to finish and raise if it
            fails; otherwise start it in the background.

    Raises:
        RuntimeError: If the operating system is not supported.
        subprocess.CalledProcessError: If ``blocking`` is True and the
            command exits with a non-zero status.
    """
    system = platform.system()

    if system == "Darwin":
        cmd = ["say", text]

    elif system == "Linux":
        cmd = ["edge-playback", "--text", text]

    elif system == "Windows":
        # The text is embedded in a single-quoted PowerShell string, where a
        # quote is escaped by doubling it; without this an apostrophe ends
        # the string and the rest would run as PowerShell code.
        escaped = text.replace("'", "''")
        cmd = [
            "PowerShell",
            "-Command",
            "Add-Type -AssemblyName System.Speech; "
            f"(New-Object System.Speech.Synthesis.SpeechSynthesizer).Speak('{escaped}')",
        ]

    else:
        raise RuntimeError("Unsupported operating system for text-to-speech.")

    if blocking:
        subprocess.run(cmd, check=True)
    else:
        # CREATE_NO_WINDOW exists only on Windows, so it is looked up lazily.
        subprocess.Popen(cmd, creationflags=subprocess.CREATE_NO_WINDOW if system == "Windows" else 0)
|
||||||
|
|
||||||
|
|
||||||
|
def log_say(text, play_sounds, blocking=False):
    """Print ``text`` and, when ``play_sounds`` is truthy, also speak it."""
    print(text)

    if not play_sounds:
        return
    say(text, blocking)
|
||||||
159
collect_data/visualize_episodes.py
Executable file
159
collect_data/visualize_episodes.py
Executable file
@@ -0,0 +1,159 @@
|
|||||||
|
#coding=utf-8
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import h5py
|
||||||
|
import argparse
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
DT = 0.02
|
||||||
|
# JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
|
||||||
|
JOINT_NAMES = ["joint0", "joint1", "joint2", "joint3", "joint4", "joint5"]
|
||||||
|
STATE_NAMES = JOINT_NAMES + ["gripper"]
|
||||||
|
BASE_STATE_NAMES = ["linear_vel", "angular_vel"]
|
||||||
|
|
||||||
|
def load_hdf5(dataset_dir, dataset_name):
    """Load one recorded episode from ``<dataset_dir>/<dataset_name>.hdf5``.

    Args:
        dataset_dir: Directory that contains the episode file.
        dataset_name: File name without the ``.hdf5`` suffix.

    Returns:
        Tuple ``(qpos, qvel, effort, action, base_action, image_dict)``.
        ``effort`` is None when the file has no ``/observations/effort``
        dataset. ``image_dict`` maps camera name to its frames; when the file
        is flagged as compressed every frame is decoded with ``cv2.imdecode``
        and the value is a list of images.

    Raises:
        SystemExit: If the dataset file does not exist.
    """
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        # Non-zero status so scripts can detect the failure.
        raise SystemExit(1)

    with h5py.File(dataset_path, 'r') as root:
        compressed = root.attrs.get('compress', False)
        observations = root['/observations']
        qpos = observations['qpos'][()]
        qvel = observations['qvel'][()]
        # effort is stored inside the observations group, so it has to be
        # looked up there rather than at the file root.
        effort = observations['effort'][()] if 'effort' in observations else None
        action = root['/action'][()]
        base_action = root['/base_action'][()]
        image_dict = {
            cam_name: observations['images'][cam_name][()]
            for cam_name in observations['images'].keys()
        }

    if compressed:
        for cam_name in image_dict:
            # Each stored frame is an encoded byte buffer; decode to colour.
            image_dict[cam_name] = [
                cv2.imdecode(encoded_frame, 1) for encoded_frame in image_dict[cam_name]
            ]

    return qpos, qvel, effort, action, base_action, image_dict
|
||||||
|
|
||||||
|
def main(args):
    """Load one episode and write its video, joint plot and base-action plot.

    Args:
        args: Mapping with ``dataset_dir``, ``episode_idx`` and ``task_name``.
            The episode is read from ``<dataset_dir>/<task_name>`` while the
            outputs are written directly into ``dataset_dir``.
    """
    dataset_dir, episode_idx, task_name = (
        args['dataset_dir'], args['episode_idx'], args['task_name']
    )
    dataset_name = f'episode_{episode_idx}'
    episode_dir = os.path.join(dataset_dir, task_name)

    qpos, qvel, effort, action, base_action, image_dict = load_hdf5(episode_dir, dataset_name)
    print('hdf5 loaded!!')

    # All artefacts share the episode name as prefix.
    output_prefix = os.path.join(dataset_dir, dataset_name)
    save_videos(image_dict, action, DT, video_path=output_prefix + '_video.mp4')
    visualize_joints(qpos, action, plot_path=output_prefix + '_qpos.png')
    visualize_base(base_action, plot_path=output_prefix + '_base_action.png')
|
||||||
|
|
||||||
|
def save_videos(video, actions, dt, video_path=None):
    """Preview all camera streams side by side and write them to an mp4 file.

    Args:
        video: Mapping of camera name to its frame sequence.
        actions: Per-frame actions; the first 7 values are printed as "left"
            and the remaining ones as "right".
        dt: Time step in seconds; the video frame rate is ``int(1 / dt)``.
        video_path: Destination path of the mp4 file.
    """
    # Join the cameras along the width axis so each frame is one wide image.
    streams = [video[cam_name] for cam_name in list(video.keys())]
    mosaic = np.concatenate(streams, axis=2)

    _, height, width, _ = mosaic.shape
    frame_rate = int(1 / dt)
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), frame_rate, (width, height))
    for t, frame in enumerate(mosaic):
        frame = frame[:, :, [2, 1, 0]]  # swap B and R channel
        cv2.imshow("images",frame)
        cv2.waitKey(30)
        print("episode_id: ", t, "left: ", np.round(actions[t][:7], 3), "right: ", np.round(actions[t][7:], 3), "\n")
        writer.write(frame)
    writer.release()
    print(f'Saved video to: {video_path}')
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
    """Plot every joint-state dimension over time and save the figure.

    Args:
        qpos_list: Joint states, one row per time step.
        command_list: Commanded values; accepted for interface compatibility
            but not plotted.
        plot_path: Destination path of the image.
        ylim: Optional y-axis limits applied to every subplot.
        label_overwrite: Optional ``(state_label, command_label)`` pair; only
            the state label is used.
    """
    if label_overwrite:
        label1, _ = label_overwrite
    else:
        label1 = 'State'

    qpos = np.array(qpos_list)  # ts, dim
    _, num_dim = qpos.shape

    # squeeze=False keeps a 2-D axes array even for a single dimension, so
    # indexing works for any num_dim >= 1.
    fig, axs = plt.subplots(num_dim, 1, figsize=(8, 2 * num_dim), squeeze=False)
    axs = axs[:, 0]

    all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
    for dim_idx in range(num_dim):
        # Fall back to a generic name when there are more dimensions than names.
        dim_name = all_names[dim_idx] if dim_idx < len(all_names) else f'dim_{dim_idx}'
        ax = axs[dim_idx]
        ax.plot(qpos[:, dim_idx], label=label1, color='orangered')
        ax.set_title(f'Joint {dim_idx}: {dim_name}')
        ax.legend()

    if ylim:
        for ax in axs:
            ax.set_ylim(ylim)

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved qpos plot to: {plot_path}')
    plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def visualize_base(readings, plot_path=None):
    """Plot raw and moving-average-smoothed base readings and save the figure.

    Each dimension gets its own subplot showing the raw signal and box-filter
    averages over 20, 10 and 5 samples.

    Args:
        readings: Base readings, one row per time step.
        plot_path: Destination path of the image.
    """
    readings = np.array(readings)  # ts, dim
    _, num_dim = readings.shape

    # squeeze=False keeps a 2-D axes array even for a single dimension.
    fig, axs = plt.subplots(num_dim, 1, figsize=(8, 2 * num_dim), squeeze=False)
    axs = axs[:, 0]

    all_names = BASE_STATE_NAMES
    for dim_idx in range(num_dim):
        signal = readings[:, dim_idx]
        # Fall back to a generic name when there are more dimensions than names.
        dim_name = all_names[dim_idx] if dim_idx < len(all_names) else f'dim_{dim_idx}'
        ax = axs[dim_idx]
        ax.plot(signal, label='raw')
        for window in (20, 10, 5):
            ax.plot(np.convolve(signal, np.ones(window)/window, mode='same'), label=f'smoothed_{window}')
        ax.set_title(f'Joint {dim_idx}: {dim_name}')
        ax.legend()

    plt.tight_layout()
    plt.savefig(plot_path)
    # This figure holds the base action, not the effort.
    print(f'Saved base action plot to: {plot_path}')
    plt.close()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Command-line entry point: parse the options and hand them to main() as a dict.
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--dataset_dir', dict(type=str, help='Dataset dir.', required=True)),
        ('--task_name', dict(type=str, help='Task name.',
                             default="aloha_mobile_dummy", required=False)),
        ('--episode_idx', dict(type=int, help='Episode index.', default=0, required=False)),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, action='store', **spec)

    main(vars(cli.parse_args()))
|
||||||
1
lerobot
Submodule
1
lerobot
Submodule
Submodule lerobot added at 1c873df5c0
18
lerobot_aloha/README.MD
Normal file
18
lerobot_aloha/README.MD
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5
|
||||||
|
|
||||||
|
# Hugging Face access token (keep real tokens out of version control)
|
||||||
|
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
|
||||||
|
|
||||||
|
# act
|
||||||
|
python lerobot/lerobot/scripts/train.py --policy.type=act --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_the_scale_with_head_camera/ --dataset.repo_id=maic/move_tube_on_scale_head --job_name=act_with_head --output_dir=outputs/train/act_move_bottle_on_scale_with_head
|
||||||
|
|
||||||
|
python lerobot/lerobot/scripts/visualize_dataset_html.py --root /home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_a_scale_without_head_camera --repo-id xxx
|
||||||
|
|
||||||
|
# pi0 ft
|
||||||
|
python lerobot/lerobot/scripts/train.py \
|
||||||
|
--policy.path=lerobot/pi0 \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--dataset.root=/home/ubuntu/LYT/lerobot_aloha/datasets/move_a_reagent_bottle_on_a_scale_without_head_camera \
|
||||||
|
--dataset.repo_id=maic/move_a_reagent_bottle_on_a_scale_without_head_camera \
|
||||||
|
--job_name=pi0_without_head \
|
||||||
|
--output_dir=outputs/train/move_a_reagent_bottle_on_a_scale_without_head_camera
|
||||||
BIN
lerobot_aloha/__pycache__/main.cpython-310.pyc
Normal file
BIN
lerobot_aloha/__pycache__/main.cpython-310.pyc
Normal file
Binary file not shown.
BIN
lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc
Normal file
BIN
lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc
Normal file
BIN
lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
@@ -1,17 +1,13 @@
|
|||||||
import yaml
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import collections
|
import collections
|
||||||
import dm_env
|
import dm_env
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Dict, List, Any, Optional
|
from typing import Dict, List, Any, Optional
|
||||||
from collections import deque
|
|
||||||
import rospy
|
import rospy
|
||||||
from cv_bridge import CvBridge
|
|
||||||
from std_msgs.msg import Header
|
from std_msgs.msg import Header
|
||||||
from sensor_msgs.msg import Image, JointState
|
from sensor_msgs.msg import JointState
|
||||||
from nav_msgs.msg import Odometry
|
from .rosrobot import Robot
|
||||||
from rosrobot import Robot
|
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
|
|
||||||
@@ -40,10 +36,15 @@ class AgilexRobot(Robot):
|
|||||||
# print("can not get data from puppet topic")
|
# print("can not get data from puppet topic")
|
||||||
# return None
|
# return None
|
||||||
|
|
||||||
if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0:
|
# 检查必要的机械臂数据是否可用
|
||||||
print("can not get data from puppet topic")
|
required_arms = ['puppet_left', 'puppet_right']
|
||||||
return None
|
for arm_name in required_arms:
|
||||||
|
if arm_name not in self.sync_arm_queues or len(self.sync_arm_queues[arm_name]) == 0:
|
||||||
|
print(f"can not get data from {arm_name} topic")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 时间戳误差
|
||||||
|
tolerance = 0.1 # 允许 100ms 的时间戳偏差
|
||||||
# 计算最小时间戳
|
# 计算最小时间戳
|
||||||
timestamps = [
|
timestamps = [
|
||||||
q[-1].header.stamp.to_sec()
|
q[-1].header.stamp.to_sec()
|
||||||
@@ -59,18 +60,16 @@ class AgilexRobot(Robot):
|
|||||||
|
|
||||||
min_time = min(timestamps)
|
min_time = min(timestamps)
|
||||||
|
|
||||||
# 检查数据同步性
|
# 检查数据同步性(允许 100ms 偏差)
|
||||||
for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
|
for queue in list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values()):
|
||||||
if queue[-1].header.stamp.to_sec() < min_time:
|
if queue[-1].header.stamp.to_sec() < min_time - tolerance:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.use_depth_image:
|
if self.use_depth_image:
|
||||||
for queue in self.sync_depth_queues.values():
|
for queue in self.sync_depth_queues.values():
|
||||||
if queue[-1].header.stamp.to_sec() < min_time:
|
if queue[-1].header.stamp.to_sec() < min_time - tolerance:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||||
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time:
|
if self.sync_base_queue[-1].header.stamp.to_sec() < min_time - tolerance:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 提取同步数据
|
# 提取同步数据
|
||||||
@@ -82,33 +81,35 @@ class AgilexRobot(Robot):
|
|||||||
|
|
||||||
# 图像数据
|
# 图像数据
|
||||||
for cam_name, queue in self.sync_img_queues.items():
|
for cam_name, queue in self.sync_img_queues.items():
|
||||||
while queue[0].header.stamp.to_sec() < min_time:
|
while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
|
||||||
queue.popleft()
|
queue.popleft()
|
||||||
frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
if queue:
|
||||||
|
frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
||||||
|
|
||||||
# 深度数据
|
# 深度数据
|
||||||
if self.use_depth_image:
|
if self.use_depth_image:
|
||||||
frame_data['depths'] = {}
|
|
||||||
for cam_name, queue in self.sync_depth_queues.items():
|
for cam_name, queue in self.sync_depth_queues.items():
|
||||||
while queue[0].header.stamp.to_sec() < min_time:
|
while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
|
||||||
queue.popleft()
|
queue.popleft()
|
||||||
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
if queue:
|
||||||
# 保持原有的边界填充
|
depth_img = self.bridge.imgmsg_to_cv2(queue.popleft(), 'passthrough')
|
||||||
frame_data['depths'][cam_name] = cv2.copyMakeBorder(
|
frame_data['depths'][cam_name] = cv2.copyMakeBorder(
|
||||||
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
|
depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# 机械臂数据
|
# 机械臂数据
|
||||||
for arm_name, queue in self.sync_arm_queues.items():
|
for arm_name, queue in self.sync_arm_queues.items():
|
||||||
while queue[0].header.stamp.to_sec() < min_time:
|
while queue and queue[0].header.stamp.to_sec() < min_time - tolerance:
|
||||||
queue.popleft()
|
queue.popleft()
|
||||||
frame_data['arms'][arm_name] = queue.popleft()
|
if queue:
|
||||||
|
frame_data['arms'][arm_name] = queue.popleft()
|
||||||
|
|
||||||
# 基座数据
|
# 基座数据
|
||||||
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
if self.use_robot_base and len(self.sync_base_queue) > 0:
|
||||||
while self.sync_base_queue[0].header.stamp.to_sec() < min_time:
|
while self.sync_base_queue and self.sync_base_queue[0].header.stamp.to_sec() < min_time - tolerance:
|
||||||
self.sync_base_queue.popleft()
|
self.sync_base_queue.popleft()
|
||||||
frame_data['base'] = self.sync_base_queue.popleft()
|
if self.sync_base_queue:
|
||||||
|
frame_data['base'] = self.sync_base_queue.popleft()
|
||||||
|
|
||||||
return frame_data
|
return frame_data
|
||||||
|
|
||||||
@@ -126,7 +127,12 @@ class AgilexRobot(Robot):
|
|||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
if any(len(q) == 0 for q in self.sync_arm_queues.values()):
|
if any(len(q) == 0 for q in self.sync_arm_queues.values()):
|
||||||
|
# 遍历字典,检查并报告每个空队列
|
||||||
|
for arm_name, queue in self.sync_arm_queues.items():
|
||||||
|
if len(queue) == 0:
|
||||||
|
print(f"{arm_name} arm not connected or queue is empty")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
# 计算最小时间戳
|
# 计算最小时间戳
|
||||||
timestamps = [
|
timestamps = [
|
||||||
@@ -206,11 +212,11 @@ class AgilexRobot(Robot):
|
|||||||
if arm_states:
|
if arm_states:
|
||||||
obs_dict["observation.state"] = torch.tensor(np.concatenate(arm_states).reshape(-1)) # 先转Python列表
|
obs_dict["observation.state"] = torch.tensor(np.concatenate(arm_states).reshape(-1)) # 先转Python列表
|
||||||
|
|
||||||
if arm_velocity:
|
# if arm_velocity:
|
||||||
obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))
|
# obs_dict["observation.velocity"] = torch.tensor(np.concatenate(arm_velocity).reshape(-1))
|
||||||
|
|
||||||
if arm_effort:
|
# if arm_effort:
|
||||||
obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))
|
# obs_dict["observation.effort"] = torch.tensor(np.concatenate(arm_effort).reshape(-1))
|
||||||
|
|
||||||
if actions:
|
if actions:
|
||||||
action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))
|
action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))
|
||||||
@@ -272,7 +278,7 @@ class AgilexRobot(Robot):
|
|||||||
|
|
||||||
# Log timing information
|
# Log timing information
|
||||||
# self.logs[f"read_arm_{arm_name}_pos_dt_s"] = time.perf_counter() - before_read_t
|
# self.logs[f"read_arm_{arm_name}_pos_dt_s"] = time.perf_counter() - before_read_t
|
||||||
print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
|
# print(f"read_arm_{arm_name}_pos_dt_s is", time.perf_counter() - before_read_t)
|
||||||
|
|
||||||
# Combine all arm states into single tensor
|
# Combine all arm states into single tensor
|
||||||
if arm_states:
|
if arm_states:
|
||||||
@@ -295,7 +301,7 @@ class AgilexRobot(Robot):
|
|||||||
|
|
||||||
# Log timing information
|
# Log timing information
|
||||||
# self.logs[f"read_camera_{cam_name}_dt_s"] = time.perf_counter() - before_camread_t
|
# self.logs[f"read_camera_{cam_name}_dt_s"] = time.perf_counter() - before_camread_t
|
||||||
print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)
|
# print(f"read_camera_{cam_name}_dt_s is", time.perf_counter() - before_camread_t)
|
||||||
|
|
||||||
# Process depth data if enabled
|
# Process depth data if enabled
|
||||||
if self.use_depth_image and 'depths' in frame_data:
|
if self.use_depth_image and 'depths' in frame_data:
|
||||||
@@ -307,7 +313,7 @@ class AgilexRobot(Robot):
|
|||||||
obs_dict[f"observation.images.depth_{cam_name}"] = depth_tensor
|
obs_dict[f"observation.images.depth_{cam_name}"] = depth_tensor
|
||||||
|
|
||||||
# self.logs[f"read_depth_{cam_name}_dt_s"] = time.perf_counter() - before_depthread_t
|
# self.logs[f"read_depth_{cam_name}_dt_s"] = time.perf_counter() - before_depthread_t
|
||||||
print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)
|
# print(f"read_depth_{cam_name}_dt_s is", time.perf_counter() - before_depthread_t)
|
||||||
|
|
||||||
# Process base velocity if enabled
|
# Process base velocity if enabled
|
||||||
if self.use_robot_base and 'base' in frame_data:
|
if self.use_robot_base and 'base' in frame_data:
|
||||||
@@ -330,12 +336,18 @@ class AgilexRobot(Robot):
|
|||||||
Returns:
|
Returns:
|
||||||
The actual action that was sent (may be clipped if safety checks are implemented)
|
The actual action that was sent (may be clipped if safety checks are implemented)
|
||||||
"""
|
"""
|
||||||
# if not hasattr(self, 'puppet_arm_publishers'):
|
# 默认速度和力矩值
|
||||||
# # Initialize publishers on first call
|
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
|
||||||
# self._init_action_publishers()
|
0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
|
||||||
|
-0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
|
||||||
|
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
|
||||||
|
-0.03296661376953125, -0.03296661376953125]
|
||||||
|
|
||||||
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
|
last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
|
||||||
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
|
0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
|
||||||
|
0.0, -0.010990142822265625, -0.010990142822265625,
|
||||||
|
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
|
||||||
|
-0.03296661376953125, -0.03296661376953125]
|
||||||
|
|
||||||
# Convert tensor to numpy array if needed
|
# Convert tensor to numpy array if needed
|
||||||
if isinstance(action, torch.Tensor):
|
if isinstance(action, torch.Tensor):
|
||||||
@@ -359,22 +371,23 @@ class AgilexRobot(Robot):
|
|||||||
arm_velocity = last_velocity[from_idx:to_idx]
|
arm_velocity = last_velocity[from_idx:to_idx]
|
||||||
arm_effort = last_effort[from_idx:to_idx]
|
arm_effort = last_effort[from_idx:to_idx]
|
||||||
from_idx = to_idx
|
from_idx = to_idx
|
||||||
|
|
||||||
|
# fix
|
||||||
|
arm_action[-1] = max(arm_action[-1]*15, 0)
|
||||||
|
|
||||||
# Apply safety checks if configured
|
# Apply safety checks if configured
|
||||||
if 'max_relative_target' in self.config:
|
|
||||||
# Get current position from the queue
|
# # Get current position from the queue
|
||||||
if len(self.sync_arm_queues[arm_name]) > 0:
|
# if len(arm_action) > 0:
|
||||||
current_state = self.sync_arm_queues[arm_name][-1]
|
|
||||||
current_pos = np.array(current_state.position)
|
# # Clip the action to stay within max relative target
|
||||||
|
# max_delta = 0.1
|
||||||
# Clip the action to stay within max relative target
|
# clipped_action = np.clip(arm_action,
|
||||||
max_delta = self.config['max_relative_target']
|
# arm_action - max_delta,
|
||||||
clipped_action = np.clip(arm_action,
|
# arm_action + max_delta)
|
||||||
current_pos - max_delta,
|
# arm_action = clipped_action
|
||||||
current_pos + max_delta)
|
|
||||||
arm_action = clipped_action
|
|
||||||
|
|
||||||
action_sent.append(arm_action)
|
# action_sent.append(arm_action)
|
||||||
|
|
||||||
# Create and publish JointState message
|
# Create and publish JointState message
|
||||||
joint_state = JointState()
|
joint_state = JointState()
|
||||||
422
lerobot_aloha/common/robot_components.py
Normal file
422
lerobot_aloha/common/robot_components.py
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
import yaml
|
||||||
|
from collections import deque
|
||||||
|
import rospy
|
||||||
|
from cv_bridge import CvBridge
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
from sensor_msgs.msg import Image, JointState
|
||||||
|
from nav_msgs.msg import Odometry
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
class RobotConfig:
    """Configuration management for robot components.

    The configuration is a plain dict loaded from a YAML file. Values supplied
    at runtime (command-line arguments) can be merged on top of it.
    """

    # Top-level sections every configuration file must provide.
    REQUIRED_SECTIONS = ('cameras', 'arm')

    # Configuration key -> attribute name looked up on the runtime arguments.
    _RUNTIME_ARG_NAMES = {
        'frame_rate': 'fps',
        'max_timesteps': 'max_timesteps',
        'episode_idx': 'episode_idx',
        'use_depth_image': 'use_depth_image',
        'use_robot_base': 'use_base',
        'video': 'video',
        'control_type': 'control_type',
    }

    def __init__(self, config_file: str):
        """
        Initialize robot configuration from YAML file

        Args:
            config_file: Path to YAML configuration file

        Raises:
            ValueError: If a required section is missing from the file.
        """
        self.config = self._load_yaml(config_file)
        self._validate_config()

    def _load_yaml(self, config_file: str) -> Dict[str, Any]:
        """Load configuration from YAML file"""
        with open(config_file, 'r') as f:
            # safe_load yields None for an empty document; fall back to an
            # empty dict so validation reports the missing sections instead of
            # failing with a TypeError on ``in None``.
            return yaml.safe_load(f) or {}

    def _validate_config(self) -> None:
        """Validate configuration completeness"""
        for section in self.REQUIRED_SECTIONS:
            if section not in self.config:
                raise ValueError(f"Missing required config section: {section}")

    def merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
        """
        Merge runtime arguments into configuration

        An attribute that is absent from ``args`` or set to ``None`` leaves
        the value from the YAML file untouched. Any other value, including
        ``False``, overrides it.

        Args:
            args: Runtime arguments from command line
        """
        if args is None:
            return

        for key, attr_name in self._RUNTIME_ARG_NAMES.items():
            # Use None as the "not provided" marker for every key. A False
            # default would silently overwrite a value configured in YAML.
            value = getattr(args, attr_name, None)
            if value is not None:
                self.config[key] = value

    def get(self, key: str, default=None) -> Any:
        """Get configuration value with default fallback"""
        return self.config.get(key, default)
|
||||||
|
|
||||||
|
|
||||||
|
class RosAdapter:
    """Adapter for ROS communication.

    Owns the CvBridge instance and two registries (``subscribers`` and
    ``publishers``). The ``create_*`` helpers only build the ROS objects;
    storing them in a registry is left to the caller.
    """

    def __init__(self, config: RobotConfig):
        """
        Initialize ROS adapter

        Args:
            config: Robot configuration
        """
        self.config = config
        self.bridge = CvBridge()
        self.subscribers = {}
        self.publishers = {}

    def init_ros_node(self, node_name: str = None) -> None:
        """Initialize the ROS node, taking the name from the config when none is given."""
        chosen_name = node_name
        if chosen_name is None:
            chosen_name = self.config.get('ros_node_name', 'generic_robot_node')
        rospy.init_node(chosen_name, anonymous=True)

    def create_subscriber(self, topic: str, msg_type, callback, queue_size: int = 1000, tcp_nodelay: bool = True):
        """Create and return a ROS subscriber for ``topic``."""
        return rospy.Subscriber(topic, msg_type, callback, queue_size=queue_size, tcp_nodelay=tcp_nodelay)

    def create_publisher(self, topic: str, msg_type, queue_size: int = 10):
        """Create and return a ROS publisher for ``topic``."""
        return rospy.Publisher(topic, msg_type, queue_size=queue_size)

    def log_status(self) -> None:
        """Log the node name and the state of every registered subscriber."""
        rospy.loginfo("\n=== ROS订阅状态 ===")
        rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
        rospy.loginfo("活跃的订阅者:")
        for topic in self.subscribers:
            # A falsy ``impl`` is reported as "not connected".
            state = '活跃' if self.subscribers[topic].impl else '未连接'
            rospy.loginfo(f" - {topic}: {state}")
        rospy.loginfo("=================")
|
||||||
|
|
||||||
|
|
||||||
|
class RobotSensors:
    """Management of robot sensors (cameras, depth sensors, mobile base).

    Incoming ROS messages are buffered in bounded deques, one per camera (and
    per depth stream / base topic when those are enabled).
    """

    # Maximum number of messages buffered per topic. A deque created with
    # ``maxlen`` drops its oldest entry automatically when a new one is
    # appended to a full queue.
    QUEUE_MAXLEN = 2000

    def __init__(self, config: RobotConfig, ros_adapter: RosAdapter):
        """
        Initialize robot sensors

        Args:
            config: Robot configuration
            ros_adapter: ROS communication adapter
        """
        self.config = config
        self.ros_adapter = ros_adapter
        self.bridge = ros_adapter.bridge

        # Camera data
        self.cameras = config.get('cameras', {})
        self.sync_img_queues = {name: deque(maxlen=self.QUEUE_MAXLEN) for name in self.cameras}

        # Depth data. The attribute is deliberately created only when depth is
        # enabled, so existing ``hasattr`` probes keep behaving the same.
        self.use_depth_image = config.get('use_depth_image', False)
        if self.use_depth_image:
            self.sync_depth_queues = {
                name: deque(maxlen=self.QUEUE_MAXLEN)
                for name, cam in self.cameras.items()
                if 'depth_topic_name' in cam
            }

        # Robot base data (attribute likewise only exists when enabled)
        self.use_robot_base = config.get('use_robot_base', False)
        if self.use_robot_base:
            self.sync_base_queue = deque(maxlen=self.QUEUE_MAXLEN)

    def setup_subscribers(self) -> None:
        """Set up ROS subscribers for sensors"""
        self._setup_camera_subscribers()
        if self.use_robot_base:
            self._setup_base_subscriber()

    def _setup_camera_subscribers(self) -> None:
        """Subscribe to the RGB topic, and the depth topic when enabled, of every camera."""
        subscribers = self.ros_adapter.subscribers
        for cam_name, cam_config in self.cameras.items():
            if 'img_topic_name' in cam_config:
                subscribers[f"camera_{cam_name}"] = self.ros_adapter.create_subscriber(
                    cam_config['img_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=False)
                )

            if self.use_depth_image and 'depth_topic_name' in cam_config:
                subscribers[f"depth_{cam_name}"] = self.ros_adapter.create_subscriber(
                    cam_config['depth_topic_name'],
                    Image,
                    self._make_camera_callback(cam_name, is_depth=True)
                )

    def _setup_base_subscriber(self) -> None:
        """Subscribe to the base odometry topic when a ``robot_base`` section is configured."""
        if 'robot_base' in self.config.config:
            self.ros_adapter.subscribers['base'] = self.ros_adapter.create_subscriber(
                self.config.get('robot_base')['topic_name'],
                Odometry,
                self.robot_base_callback
            )

    def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
        """Return a callback that buffers messages for ``cam_name`` (RGB or depth queue)."""
        def callback(msg):
            try:
                target_queue = (
                    self.sync_depth_queues[cam_name]
                    if is_depth
                    else self.sync_img_queues[cam_name]
                )
                # No manual eviction needed: the bounded deque discards the
                # oldest message itself once QUEUE_MAXLEN is reached.
                target_queue.append(msg)
            except Exception as e:
                # Log instead of raising so a bad message does not propagate
                # out of the subscriber callback.
                rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
        return callback

    def robot_base_callback(self, msg):
        """Buffer a base message; the bounded deque drops the oldest one when full."""
        self.sync_base_queue.append(msg)

    def init_features(self) -> Dict[str, Any]:
        """Build the feature description dict for all enabled sensors."""
        features = {}

        # Initialize camera features
        self._init_camera_features(features)

        # Initialize base features (if enabled)
        if self.use_robot_base:
            self._init_base_features(features)

        return features

    def _init_camera_features(self, features: Dict[str, Any]) -> None:
        """Add the RGB (and optional depth) feature entry of every camera to ``features``."""
        for cam_name, cam_config in self.cameras.items():
            # Regular images
            features[f"observation.images.{cam_name}"] = {
                "dtype": "video" if self.config.get("video", False) else "image",
                "shape": cam_config.get("rgb_shape", [480, 640, 3]),
                "names": ["height", "width", "channel"],
            }

            if self.config.get("use_depth_image", False):
                features[f"observation.images.depth_{cam_name}"] = {
                    "dtype": "uint16",
                    # Height first, then width, matching the RGB entry above.
                    # The 'height' key feeds the height axis and 'width' the
                    # width axis; the defaults still yield (480, 640, 1).
                    "shape": (cam_config.get("height", 480), cam_config.get("width", 640), 1),
                    "names": ["height", "width"],
                }

    def _init_base_features(self, features: Dict[str, Any]) -> None:
        """Add the base velocity feature entry to ``features``."""
        features["observation.base_vel"] = {
            "dtype": "float32",
            "shape": (2,),
            "names": ["linear_x", "angular_z"]
        }
|
||||||
|
|
||||||
|
|
||||||
|
class RobotActuators:
    """Management of robot actuators (arms): message queues, ROS wiring and feature specs."""

    def __init__(self, config: RobotConfig, ros_adapter: RosAdapter):
        """
        Initialize robot actuators

        Args:
            config: Robot configuration
            ros_adapter: ROS communication adapter
        """
        self.config = config
        self.ros_adapter = ros_adapter

        # Arm data
        self.arms = config.get('arm', {})
        recording = config.get('control_type', '') == 'record'
        # Record mode buffers every arm; any other mode buffers only arms
        # whose name contains 'puppet'.
        self.sync_arm_queues = {
            name: deque(maxlen=2000)
            for name in self.arms
            if recording or 'puppet' in name
        }

    def setup_subscribers_publishers(self) -> None:
        """Set up ROS subscribers and publishers for actuators"""
        self._setup_arm_subscribers_publishers()

    def _subscribe_arm(self, arm_name: str, topic: str) -> None:
        """Create a JointState subscriber for ``arm_name`` and register it on the adapter."""
        self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber(
            topic,
            JointState,
            self._make_arm_callback(arm_name)
        )

    def _setup_arm_subscribers_publishers(self) -> None:
        """Wire up the arms according to the control type.

        In record mode every arm that declares a ``topic_name`` is subscribed
        to. In any other mode puppet arms are subscribed to and master arms
        get a publisher.
        """
        recording = self.config.get('control_type', '') == 'record'
        for arm_name, arm_config in self.arms.items():
            if recording:
                if 'topic_name' in arm_config:
                    self._subscribe_arm(arm_name, arm_config['topic_name'])
                continue

            if 'puppet' in arm_name:
                self._subscribe_arm(arm_name, arm_config['topic_name'])
            if 'master' in arm_name:
                self.ros_adapter.publishers[f"arm_{arm_name}"] = self.ros_adapter.create_publisher(
                    arm_config['topic_name'],
                    JointState
                )

    def _make_arm_callback(self, arm_name: str):
        """Return a callback that appends incoming messages to the queue of ``arm_name``."""
        def callback(msg):
            try:
                arm_queue = self.sync_arm_queues[arm_name]
                if len(arm_queue) >= 2000:
                    arm_queue.popleft()
                arm_queue.append(msg)
            except Exception as e:
                rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
        return callback

    def init_features(self) -> Dict[str, Any]:
        """Build the feature description dict (state and action) for the arms."""
        features = {}
        self._init_state_features(features)
        self._init_action_features(features)
        return features

    def _init_state_features(self, features: Dict[str, Any]) -> None:
        """Add the ``observation.state`` entry, sized by the configured state motors."""
        state_motors = self.config.get('state', {}).get('motors', "")
        features["observation.state"] = {
            "dtype": "float32",
            "shape": (len(state_motors),),
            "names": {"motors": state_motors}
        }
        # observation.velocity / observation.effort entries are currently
        # disabled and therefore not declared here.

    def _init_action_features(self, features: Dict[str, Any]) -> None:
        """Add the ``action`` entry, sized by the configured action motors."""
        action_motors = self.config.get('action', {}).get('motors', "")
        features["action"] = {
            "dtype": "float32",
            "shape": (len(action_motors),),
            "names": {"motors": action_motors}
        }
|
||||||
|
|
||||||
|
|
||||||
|
class RobotDataManager:
    """Management of robot data collection and synchronization"""

    # Minimum number of buffered messages required before warmup succeeds.
    MIN_IMAGE_MSGS = 200  # per RGB camera and per depth stream
    MIN_BASE_MSGS = 20    # base queue

    def __init__(self, config: RobotConfig, sensors: RobotSensors, actuators: RobotActuators):
        """
        Initialize robot data manager

        Args:
            config: Robot configuration
            sensors: Robot sensors component
            actuators: Robot actuators component
        """
        self.config = config
        self.sensors = sensors
        self.actuators = actuators

    def warmup(self, timeout: float = 30.0) -> bool:
        """
        Wait until the sensor queues have buffered enough messages

        Checked queues: every camera image queue, every depth queue (when
        depth is enabled) and the base queue (when the base is enabled).
        Arm queues are not inspected here.

        Args:
            timeout: Maximum time to wait in seconds

        Returns:
            bool: True if warmup succeeded, False if timed out or ROS shut down
        """
        start_time = rospy.Time.now().to_sec()
        rate = rospy.Rate(10)  # Check at 10Hz

        rospy.loginfo("Starting warmup process...")
        # The camera count is logged once up front rather than on every poll.
        rospy.loginfo(f"Nums of camera is {len(self.sensors.cameras)}")

        while not rospy.is_shutdown():
            # Check if timeout has been reached
            if rospy.Time.now().to_sec() - start_time > timeout:
                rospy.logwarn("Warmup timed out before all queues were filled")
                return False

            if self._queues_ready():
                rospy.loginfo("Warmup completed successfully")
                return True

            rate.sleep()

        return False

    def _queues_ready(self) -> bool:
        """Return True when every checked queue holds enough messages, logging the ones that do not."""
        all_ready = True

        # Camera image queues: report only the first camera still filling up.
        for cam_name in self.sensors.cameras:
            buffered = len(self.sensors.sync_img_queues[cam_name])
            if buffered < self.MIN_IMAGE_MSGS:
                rospy.loginfo(f"Waiting for camera {cam_name} (current: {buffered}/{self.MIN_IMAGE_MSGS})")
                all_ready = False
                break

        # Depth queues, if enabled
        if self.sensors.use_depth_image:
            for cam_name in self.sensors.sync_depth_queues:
                buffered = len(self.sensors.sync_depth_queues[cam_name])
                if buffered < self.MIN_IMAGE_MSGS:
                    rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {buffered}/{self.MIN_IMAGE_MSGS})")
                    all_ready = False
                    break

        # Base queue, if enabled
        if self.sensors.use_robot_base:
            buffered = len(self.sensors.sync_base_queue)
            if buffered < self.MIN_BASE_MSGS:
                rospy.loginfo(f"Waiting for base (current: {buffered}/{self.MIN_BASE_MSGS})")
                all_ready = False

        return all_ready
|
||||||
136
lerobot_aloha/common/rosrobot.py
Normal file
136
lerobot_aloha/common/rosrobot.py
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
import yaml
|
||||||
|
from typing import Dict, Any, Optional, List
|
||||||
|
import argparse
|
||||||
|
from .robot_components import RobotConfig, RosAdapter, RobotSensors, RobotActuators, RobotDataManager
|
||||||
|
|
||||||
|
|
||||||
|
class Robot:
    """Base class for ROS robots that owns the shared start-up sequence.

    The constructor builds the configuration, ROS adapter, sensors, actuators
    and data manager, then initialises ROS, the feature schema and runs the
    warm-up. Subclasses must implement ``get_frame`` and ``process``.
    """

    def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
        """Create every component and run the start-up sequence.

        Args:
            config_file: Path to the YAML configuration file.
            args: Runtime arguments merged into the configuration.
        """
        # Build the components; each one receives the earlier ones it needs.
        self.config = RobotConfig(config_file)
        self.config.merge_runtime_args(args)
        self.ros_adapter = RosAdapter(self.config)
        self.sensors = RobotSensors(self.config, self.ros_adapter)
        self.actuators = RobotActuators(self.config, self.ros_adapter)
        self.data_manager = RobotDataManager(self.config, self.sensors, self.actuators)

        # Bring up ROS, derive the feature schema, then run the warm-up.
        self.init_ros()
        self.init_features()
        self.warmup()

    def get(self, key: str, default=None) -> Any:
        """Return the configuration value stored under ``key`` (or ``default``)."""
        return self.config.get(key, default)

    @property
    def bridge(self):
        """CV bridge held by the ROS adapter."""
        return self.ros_adapter.bridge

    @property
    def subscribers(self):
        """Subscribers held by the ROS adapter."""
        return self.ros_adapter.subscribers

    @property
    def publishers(self):
        """Publishers held by the ROS adapter."""
        return self.ros_adapter.publishers

    @property
    def cameras(self):
        """Camera configuration exposed by the sensors component."""
        return self.sensors.cameras

    @property
    def arms(self):
        """Arm configuration exposed by the actuators component."""
        return self.actuators.arms

    @property
    def sync_img_queues(self):
        """Image queues exposed by the sensors component."""
        return self.sensors.sync_img_queues

    @property
    def sync_depth_queues(self):
        """Depth-image queues, or an empty dict when the sensors have none."""
        return getattr(self.sensors, 'sync_depth_queues', {})

    @property
    def sync_arm_queues(self):
        """Arm queues exposed by the actuators component."""
        return self.actuators.sync_arm_queues

    @property
    def sync_base_queue(self):
        """Base queue, or ``None`` when the sensors have none."""
        return getattr(self.sensors, 'sync_base_queue', None)

    @property
    def use_depth_image(self):
        """Whether depth images are in use, as reported by the sensors."""
        return self.sensors.use_depth_image

    @property
    def use_robot_base(self):
        """Whether the robot base is in use, as reported by the sensors."""
        return self.sensors.use_robot_base

    def init_ros(self) -> None:
        """Template method: start the ROS node and wire up all topics."""
        self.ros_adapter.init_ros_node()

        # Sensors subscribe first, then actuators add subscribers and publishers.
        self.sensors.setup_subscribers()
        self.actuators.setup_subscribers_publishers()

        # Log the resulting ROS status.
        self.ros_adapter.log_status()

    def init_features(self):
        """Build ``self.features`` from the sensor and actuator features and print it."""
        from pprint import pprint

        # Sensor features go in first; actuator features are merged on top.
        self.features = {}
        for component in (self.sensors, self.actuators):
            self.features.update(component.init_features())

        pprint(self.features, indent=4)

    def warmup(self, timeout: float = 30.0) -> bool:
        """Delegate the queue warm-up to the data manager.

        Args:
            timeout: Maximum time to wait in seconds before giving up.

        Returns:
            bool: True if warmup succeeded, False if timed out.
        """
        return self.data_manager.warmup(timeout)

    def get_frame(self) -> Optional[Dict[str, Any]]:
        """Template method returning one synchronised frame; subclasses override it."""
        raise NotImplementedError("Subclasses must implement get_frame()")

    def process(self) -> tuple:
        """Template method for the main processing loop; subclasses override it."""
        raise NotImplementedError("Subclasses must implement process()")
|
||||||
59
lerobot_aloha/common/rosrobot_factory.py
Normal file
59
lerobot_aloha/common/rosrobot_factory.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
import yaml
|
||||||
|
import argparse
|
||||||
|
from typing import Dict, List, Any, Optional, Type
|
||||||
|
from .rosrobot import Robot
|
||||||
|
from .agilex_robot import AgilexRobot
|
||||||
|
|
||||||
|
|
||||||
|
class RobotFactory:
    """Factory for creating robot instances based on configuration"""

    # Registry mapping a robot-type identifier to the class implementing it.
    _registry = {}

    @classmethod
    def register(cls, robot_type: str, robot_class: Type[Robot]) -> None:
        """Register a robot type.

        Args:
            robot_type: Identifier of the robot type.
            robot_class: Class implementing that robot type.
        """
        cls._registry[robot_type] = robot_class

    @classmethod
    def create(cls, config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
        """Create the robot instance selected by the configuration file.

        The ``robot_type`` key of the YAML file picks the class; it defaults
        to ``'agilex'`` when the key (or the whole document) is missing.

        Args:
            config_file: Path to the YAML configuration file.
            args: Runtime arguments forwarded to the robot constructor.

        Returns:
            Robot: The created robot instance.

        Raises:
            ValueError: If the configured robot type is not registered.
        """
        with open(config_file, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat it as an
            # empty mapping so the default robot type still applies.
            config = yaml.safe_load(f) or {}

        robot_type = config.get('robot_type', 'agilex')

        # If the registry is empty, register the default robot types.
        if not cls._registry:
            cls.register('agilex', AgilexRobot)
            cls.register('aloha_agilex', AgilexRobot)  # alias

        # Look the robot class up in the registry.
        if robot_type not in cls._registry:
            raise ValueError(f"Unsupported robot type: {robot_type}. Available types: {list(cls._registry.keys())}")
        return cls._registry[robot_type](config_file, args)
|
||||||
|
|
||||||
|
|
||||||
|
# 注册可用的机器人类型
|
||||||
|
RobotFactory.register('agilex', AgilexRobot)
|
||||||
|
RobotFactory.register('aloha_agilex', AgilexRobot) # 别名支持
|
||||||
12
lerobot_aloha/common/utils/__init__.py
Normal file
12
lerobot_aloha/common/utils/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Import utility functions for easy access
|
||||||
|
from .control_utils import (
|
||||||
|
predict_action,
|
||||||
|
control_loop,
|
||||||
|
init_keyboard_listener,
|
||||||
|
stop_recording,
|
||||||
|
record_episode,
|
||||||
|
is_headless,
|
||||||
|
busy_wait
|
||||||
|
)
|
||||||
|
from .data_utils import record
|
||||||
|
from .replay_utils import replay
|
||||||
BIN
lerobot_aloha/common/utils/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
lerobot_aloha/common/utils/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
312
lerobot_aloha/common/utils/control_utils.py
Normal file
312
lerobot_aloha/common/utils/control_utils.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import rospy
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from copy import copy
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.utils.utils import get_safe_torch_device
|
||||||
|
|
||||||
|
|
||||||
|
def is_headless():
    """
    Check if the environment is headless (no display available).

    The probe tries to create and immediately destroy a hidden Tk root
    window; any failure along the way (tkinter missing, no display, ...) is
    reported as headless.

    Returns:
        bool: True if the environment is headless, False otherwise.
    """
    try:
        import tkinter as tk

        root = tk.Tk()
        try:
            root.withdraw()
            root.update()
        finally:
            # Always release the probe window, even if the calls above fail.
            root.destroy()
        return False
    except Exception:
        # Deliberately broad, but unlike a bare ``except`` it lets
        # KeyboardInterrupt / SystemExit propagate.
        return True
|
||||||
|
|
||||||
|
|
||||||
|
def predict_action(observation, policy, device, use_amp):
    """
    Run the policy on one observation and return the action on the CPU.

    Args:
        observation: Mapping from feature name to tensor for the current step
        policy: Policy model exposing ``select_action``
        device: Torch device the tensors are moved to
        use_amp: Whether to use automatic mixed precision (CUDA only)

    Returns:
        torch.Tensor: Predicted action without the batch dimension
    """
    # Shallow copy so the caller's mapping keeps its original tensors.
    batch = copy(observation)
    amp_enabled = device.type == "cuda" and use_amp
    amp_context = torch.autocast(device_type=device.type) if amp_enabled else nullcontext()

    with torch.inference_mode(), amp_context:
        for key in batch:
            tensor = batch[key]
            if "image" in key:
                # Images become float32 in [0, 1] and channel-first.
                tensor = (tensor.type(torch.float32) / 255).permute(2, 0, 1).contiguous()
            # Every entry gets a leading batch dimension and moves to the device.
            batch[key] = tensor.unsqueeze(0).to(device)

        # Ask the policy for the next action, then drop the batch dimension
        # and bring the result back to the CPU.
        return policy.select_action(batch).squeeze(0).to("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def _show_camera_grid(observation, cell_width=640, cell_height=480):
    """Tile the non-depth images of ``observation`` into one OpenCV window.

    Args:
        observation: Mapping from feature name to tensor; keys containing
            "image" but not "depth" are displayed.
        cell_width: Width in pixels each image is resized to.
        cell_height: Height in pixels each image is resized to.
    """
    image_keys = [key for key in observation if "image" in key and "depth" not in key]
    if not image_keys:
        return

    # Pick rows/columns so the grid is as close to square as possible.
    grid_cols = int(np.ceil(np.sqrt(len(image_keys))))
    grid_rows = int(np.ceil(len(image_keys) / grid_cols))

    # One canvas large enough to hold every resized image.
    canvas = np.zeros((grid_rows * cell_height, grid_cols * cell_width, 3), dtype=np.uint8)

    for idx, key in enumerate(image_keys):
        row, col = divmod(idx, grid_cols)
        y_start = row * cell_height
        x_start = col * cell_width

        # Convert the tensor to a BGR array for OpenCV and fit it to the cell.
        image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
        canvas[y_start:y_start + cell_height, x_start:x_start + cell_width] = cv2.resize(
            image, (cell_width, cell_height)
        )

        # Label the cell with the observation key.
        cv2.putText(canvas, key, (x_start + 5, y_start + 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    cv2.imshow("Camera Views", canvas)
    cv2.waitKey(1)


def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy = None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """
    Main control loop for robot operation.

    Args:
        robot: Robot instance
        control_time_s: Control time in seconds (``None`` means no limit)
        teleoperate: Whether to use teleoperation
        display_cameras: Whether to display camera feeds
        dataset: Dataset for recording
        events: Event dictionary; ``events["exit_early"]`` ends the loop
        policy: Policy model
        fps: Frames per second (``None`` disables the per-frame throttle)
        single_task: Task name

    Raises:
        ValueError: If a dataset is given without ``single_task``, or if the
            dataset fps differs from the requested ``fps``.
    """
    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        # Use the attribute: indexing the dataset would fetch a sample and
        # mask this error with an unrelated one.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    # Probing for a display creates a Tk window, so do it once instead of on
    # every frame.
    show_cameras = display_cameras and not is_headless()

    # rospy.Rate needs a frequency; when the loop is not throttled, failed
    # synchronisation attempts are still retried at 30 Hz.
    rate = rospy.Rate(fps if fps is not None else 30)

    timestamp = 0
    start_episode_t = time.perf_counter()
    print_flag = True
    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                # Report the first synchronisation failure only, then retry.
                if print_flag:
                    print("sync data fail, retrying...\n")
                    print_flag = False
                rate.sleep()
                continue
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if show_cameras:
            _show_camera_grid(observation)

        if fps is not None:
            # Spend whatever is left of this frame's time budget.
            busy_wait(1 / fps - (time.perf_counter() - start_loop_t))

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
|
||||||
|
|
||||||
|
|
||||||
|
def init_keyboard_listener():
    """
    Create the shared event flags and, when a display is available, start a
    keyboard listener that updates them.

    Key bindings: right arrow ends the current loop, left arrow ends it and
    requests a re-record, escape stops recording, up arrow starts recording.
    Monitoring keyboard events might require sudo permission.

    Returns:
        tuple: (listener, events) - Keyboard listener (``None`` when
        headless) and events dictionary
    """
    events = {
        "exit_early": False,
        "record_start": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events.update(exit_early=True, record_start=False)
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                events.update(rerecord_episode=True, exit_early=True)
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events.update(stop_recording=True, exit_early=True)
            elif key == keyboard.Key.up:
                print("Up arrow pressed. Start data recording...")
                events["record_start"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    key_listener = keyboard.Listener(on_press=on_press)
    key_listener.start()
    return key_listener, events
|
||||||
|
|
||||||
|
|
||||||
|
def stop_recording(robot, listener, display_cameras):
    """
    Release the keyboard listener and any camera windows.

    Nothing is done in a headless environment.

    Args:
        robot: Robot instance (not used here)
        listener: Keyboard listener, or ``None``
        display_cameras: Whether cameras are being displayed
    """
    if is_headless():
        return

    if listener is not None:
        listener.stop()

    if display_cameras:
        cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
|
||||||
|
def record_episode(robot, dataset, events, episode_time_s, display_cameras, policy, fps, single_task):
    """
    Record a single episode by running the control loop once.

    Args:
        robot: Robot instance
        dataset: Dataset for recording
        events: Event dictionary
        episode_time_s: Episode time in seconds
        display_cameras: Whether to display camera feeds
        policy: Policy model; when ``None`` the loop runs in teleoperation mode
        fps: Frames per second
        single_task: Task name
    """
    # Without a policy the episode is recorded through teleoperation.
    use_teleoperation = policy is None
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        teleoperate=use_teleoperation,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        single_task=single_task,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def busy_wait(seconds):
    """
    Spin until the given number of seconds has elapsed.

    Non-positive durations return immediately.

    Args:
        seconds: Number of seconds to wait
    """
    if seconds > 0:
        began_at = time.perf_counter()
        # Deliberate spin loop: poll the clock until enough time has passed.
        while time.perf_counter() - began_at < seconds:
            continue
|
||||||
105
lerobot_aloha/common/utils/data_utils.py
Normal file
105
lerobot_aloha/common/utils/data_utils.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pprint import pprint
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.policies.factory import make_policy
|
||||||
|
from lerobot.common.utils.utils import log_say, has_method
|
||||||
|
from common.utils.control_utils import init_keyboard_listener, stop_recording, record_episode
|
||||||
|
|
||||||
|
|
||||||
|
def record(robot, cfg) -> LeRobotDataset:
    """
    Record episodes from the robot into a dataset, as described by ``cfg``.

    Args:
        robot: Robot instance
        cfg: Configuration object

    Returns:
        LeRobotDataset: Dataset with recorded episodes
    """
    camera_count = len(robot.cameras)

    if cfg.resume:
        # Continue an existing dataset.
        dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
        if camera_count > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * camera_count,
            )
    else:
        # Start a new dataset using the robot's feature schema.
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * camera_count,
        )

    # A policy is only built when one is configured.
    if cfg.policy is None:
        policy = None
    else:
        policy = make_policy(cfg.policy, ds_meta=dataset.meta)

    listener, events = init_keyboard_listener()

    # Show the operator the recording instructions.
    print()
    print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")

    recorded_episodes = 0
    while recorded_episodes < cfg.num_episodes:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Announce a reset unless recording was stopped; the last episode
        # only gets one when it is about to be re-recorded.
        more_episodes_pending = recorded_episodes < cfg.num_episodes - 1
        if not events["stop_recording"] and (more_episodes_pending or events["rerecord_episode"]):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")

        if events["rerecord_episode"]:
            # Drop the buffered frames and record the same episode again.
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset
|
||||||
32
lerobot_aloha/common/utils/replay_utils.py
Normal file
32
lerobot_aloha/common/utils/replay_utils.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import time
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from common.utils.control_utils import busy_wait
|
||||||
|
|
||||||
|
|
||||||
|
def replay(robot, cfg):
    """
    Send the recorded actions of one episode back to the robot.

    Args:
        robot: Robot instance
        cfg: Configuration object (``repo_id``, ``root``, ``episode`` and
            ``fps`` are read)
    """
    # Load only the requested episode and keep just its action column.
    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")

    print(f"Replaying episode {cfg.episode} from dataset {cfg.repo_id}")
    print(f"Total frames: {dataset.num_frames}")

    for frame_idx in range(dataset.num_frames):
        frame_started_at = time.perf_counter()

        robot.send_action(actions[frame_idx]["action"])

        # Wait out the rest of this frame's 1 / fps time slot.
        elapsed_s = time.perf_counter() - frame_started_at
        busy_wait(1 / cfg.fps - elapsed_s)
|
||||||
@@ -1,12 +1,20 @@
|
|||||||
robot_type: aloha_agilex
|
robot_type: aloha_agilex
|
||||||
ros_node_name: record_episodes
|
ros_node_name: record_episodes
|
||||||
cameras:
|
cameras:
|
||||||
cam_front:
|
cam_high:
|
||||||
|
# img_topic_name: /camera/color/image_raw
|
||||||
|
# depth_topic_name: /camera/depth/image_rect_raw
|
||||||
img_topic_name: /camera_f/color/image_raw
|
img_topic_name: /camera_f/color/image_raw
|
||||||
depth_topic_name: /camera_f/depth/image_raw
|
depth_topic_name: /camera_f/depth/image_raw
|
||||||
|
rgb_shape: [480, 640, 3]
|
||||||
width: 480
|
width: 480
|
||||||
height: 640
|
height: 640
|
||||||
rgb_shape: [480, 640, 3]
|
# cam_front:
|
||||||
|
# img_topic_name: /camera_f/color/image_raw
|
||||||
|
# depth_topic_name: /camera_f/depth/image_raw
|
||||||
|
# width: 480
|
||||||
|
# height: 640
|
||||||
|
# rgb_shape: [480, 640, 3]
|
||||||
cam_left:
|
cam_left:
|
||||||
img_topic_name: /camera_l/color/image_raw
|
img_topic_name: /camera_l/color/image_raw
|
||||||
depth_topic_name: /camera_l/depth/image_raw
|
depth_topic_name: /camera_l/depth/image_raw
|
||||||
@@ -20,6 +28,7 @@ cameras:
|
|||||||
width: 480
|
width: 480
|
||||||
height: 640
|
height: 640
|
||||||
|
|
||||||
|
|
||||||
arm:
|
arm:
|
||||||
master_left:
|
master_left:
|
||||||
topic_name: /master/joint_left
|
topic_name: /master/joint_left
|
||||||
@@ -85,41 +94,41 @@ state:
|
|||||||
"right_none"
|
"right_none"
|
||||||
]
|
]
|
||||||
|
|
||||||
velocity:
|
# velocity:
|
||||||
motors: [
|
# motors: [
|
||||||
"left_joint0",
|
# "left_joint0",
|
||||||
"left_joint1",
|
# "left_joint1",
|
||||||
"left_joint2",
|
# "left_joint2",
|
||||||
"left_joint3",
|
# "left_joint3",
|
||||||
"left_joint4",
|
# "left_joint4",
|
||||||
"left_joint5",
|
# "left_joint5",
|
||||||
"left_none",
|
# "left_none",
|
||||||
"right_joint0",
|
# "right_joint0",
|
||||||
"right_joint1",
|
# "right_joint1",
|
||||||
"right_joint2",
|
# "right_joint2",
|
||||||
"right_joint3",
|
# "right_joint3",
|
||||||
"right_joint4",
|
# "right_joint4",
|
||||||
"right_joint5",
|
# "right_joint5",
|
||||||
"right_none"
|
# "right_none"
|
||||||
]
|
# ]
|
||||||
|
|
||||||
effort:
|
# effort:
|
||||||
motors: [
|
# motors: [
|
||||||
"left_joint0",
|
# "left_joint0",
|
||||||
"left_joint1",
|
# "left_joint1",
|
||||||
"left_joint2",
|
# "left_joint2",
|
||||||
"left_joint3",
|
# "left_joint3",
|
||||||
"left_joint4",
|
# "left_joint4",
|
||||||
"left_joint5",
|
# "left_joint5",
|
||||||
"left_none",
|
# "left_none",
|
||||||
"right_joint0",
|
# "right_joint0",
|
||||||
"right_joint1",
|
# "right_joint1",
|
||||||
"right_joint2",
|
# "right_joint2",
|
||||||
"right_joint3",
|
# "right_joint3",
|
||||||
"right_joint4",
|
# "right_joint4",
|
||||||
"right_joint5",
|
# "right_joint5",
|
||||||
"right_none"
|
# "right_none"
|
||||||
]
|
# ]
|
||||||
|
|
||||||
action:
|
action:
|
||||||
motors: [
|
motors: [
|
||||||
239
lerobot_aloha/gui_app.py
Normal file
239
lerobot_aloha/gui_app.py
Normal file
@@ -0,0 +1,239 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||||||
|
QHBoxLayout, QLabel, QLineEdit, QSpinBox, QCheckBox,
|
||||||
|
QPushButton, QGroupBox, QTabWidget, QScrollArea, QGridLayout)
|
||||||
|
from PyQt5.QtCore import Qt, QTimer
|
||||||
|
import cv2
|
||||||
|
from PyQt5.QtGui import QImage, QPixmap
|
||||||
|
from main import get_arguments, control_robot
|
||||||
|
|
||||||
|
class ConfigGroup(QGroupBox):
    """Titled group box that stacks labelled configuration widgets vertically."""

    def __init__(self, title, parent=None):
        super().__init__(title, parent)
        # Callers add extra rows through this attribute as well.
        self.layout = QVBoxLayout()
        self.setLayout(self.layout)

    def add_config(self, name, value, widget_type="lineedit"):
        """Append a labelled editor for ``value`` and return the editor widget.

        The editor is chosen from the type of ``value``: a check box for
        bools, a spin box (0..999999) for ints, otherwise a line edit holding
        ``str(value)``. ``widget_type`` is accepted but not consulted.
        """
        # bool is tested before int because bool is a subclass of int.
        if isinstance(value, bool):
            editor = QCheckBox()
            editor.setChecked(value)
        elif isinstance(value, int):
            editor = QSpinBox()
            editor.setRange(0, 999999)
            editor.setValue(value)
        else:
            editor = QLineEdit(str(value))

        line = QHBoxLayout()
        line.addWidget(QLabel(name))
        line.addWidget(editor)
        self.layout.addLayout(line)
        return editor
|
||||||
|
|
||||||
|
class MainWindow(QMainWindow):
|
||||||
|
def __init__(self):
    """Build the scrollable configuration panel and the record/stop buttons."""
    super().__init__()
    self.setWindowTitle("MindRobot-V1 Control GUI")
    self.setGeometry(100, 100, 600, 800)
    self.robot = None
    # No configuration object yet.
    self.cfg = None

    def browse_row(caption, line_edit, on_browse):
        """Return a row holding a caption, a line edit and a Browse button."""
        row = QHBoxLayout()
        button = QPushButton("Browse...")
        button.clicked.connect(on_browse)
        row.addWidget(QLabel(caption))
        row.addWidget(line_edit)
        row.addWidget(button)
        return row

    # Top-level container.
    central = QWidget()
    central_layout = QHBoxLayout()
    central.setLayout(central_layout)

    # Scrollable panel holding every configuration group.
    scroll_area = QScrollArea()
    panel = QWidget()
    panel_layout = QVBoxLayout()
    panel.setLayout(panel_layout)

    # --- General settings ---
    general = ConfigGroup("General Settings")
    self.single_task_widget = general.add_config("Single Task", "")
    self.single_task_widget.setMinimumWidth(300)
    self.single_task_widget.textChanged.connect(self.update_repo_id_from_task)
    self.fps_widget = general.add_config("FPS", 30, "spinbox")
    self.resume_widget = general.add_config("Resume", False, "checkbox")
    self.repo_id_widget = general.add_config("Repo ID", "")
    self.repo_id_widget.textChanged.connect(self.update_root_from_repo_id)
    self.original_repo_id = ""  # value the repo id is compared against later
    self.play_sounds_widget = general.add_config("Play Sounds", False, "checkbox")

    # Config file path with a Browse button.
    self.config_widget = QLineEdit("/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml")
    general.layout.addLayout(browse_row("Config File", self.config_widget, self.browse_config_file))

    # Dataset root directory with a Browse button.
    self.root_widget = QLineEdit(str("/home/ubuntu/LYT/lerobot_aloha/datasets/"+self.repo_id_widget.text()))
    general.layout.addLayout(browse_row("Root Directory", self.root_widget, self.browse_root_directory))
    panel_layout.addWidget(general)

    # --- Recording settings ---
    recording = ConfigGroup("Recording Settings")
    self.episode_widget = recording.add_config("Episode", 0, "spinbox")
    self.num_episodes_widget = recording.add_config("Number of Episodes", 100, "spinbox")
    self.episode_time_widget = recording.add_config("Episode Time (ms)", 36000, "spinbox")
    self.video_widget = recording.add_config("Save Video", True, "checkbox")
    self.display_cameras_widget = recording.add_config("Display Cameras", True, "checkbox")
    panel_layout.addWidget(recording)

    # --- Advanced settings ---
    advanced = ConfigGroup("Advanced Settings")
    self.num_writer_processes_widget = advanced.add_config("Writer Processes", 0, "spinbox")
    self.num_writer_threads_widget = advanced.add_config("Threads per Camera", 4, "spinbox")
    self.use_depth_widget = advanced.add_config("Use Depth Image", False, "checkbox")
    self.use_base_widget = advanced.add_config("Use Base", False, "checkbox")
    self.push_to_hub_widget = advanced.add_config("Push to Hub", False, "checkbox")
    self.policy_widget = advanced.add_config("Policy", None)
    panel_layout.addWidget(advanced)

    # --- Record / Stop buttons ---
    button_row = QHBoxLayout()
    self.record_button = QPushButton("Record")
    self.record_button.clicked.connect(self.start_recording)
    button_row.addWidget(self.record_button)

    self.stop_button = QPushButton("Stop")
    self.stop_button.clicked.connect(self.stop_recording)
    button_row.addWidget(self.stop_button)
    panel_layout.addLayout(button_row)

    scroll_area.setWidget(panel)
    scroll_area.setWidgetResizable(True)
    central_layout.addWidget(scroll_area, stretch=1)

    self.setCentralWidget(central)

    # Recording state flag.
    self.is_recording = False
|
||||||
|
|
||||||
|
|
||||||
|
def browse_config_file(self):
    """Let the user pick a YAML config file and store its path in the config field."""
    from PyQt5.QtWidgets import QFileDialog

    # The dialog opens at the path currently shown in the field.
    chosen_path, _selected_filter = QFileDialog.getOpenFileName(
        self,
        "Select Config File",
        self.config_widget.text(),
        "YAML Files (*.yaml *.yml)"
    )
    # An empty path leaves the field untouched.
    if not chosen_path:
        return
    self.config_widget.setText(chosen_path)
|
||||||
|
|
||||||
|
def update_repo_id_from_task(self, text):
    """Mirror the task description into the repo-id field.

    Every space in ``text`` is replaced with an underscore and the result is
    written to ``self.repo_id_widget``. Nothing is written while
    ``self.repo_id_edited`` is truthy; a missing flag counts as "not edited".

    Args:
        text: Current contents of the single-task field.
    """
    if getattr(self, "repo_id_edited", False):
        return
    self.repo_id_widget.setText(text.replace(" ", "_"))
|
||||||
|
|
||||||
|
def update_root_from_repo_id(
    self,
    text,
    datasets_dir="/home/ubuntu/LYT/lerobot_aloha/datasets",
):
    """Point the root-directory field at ``<datasets_dir>/<repo id>``.

    Args:
        text: Repo id to append to the datasets directory.
        datasets_dir: Directory that holds all datasets. Defaults to the
            location this method previously hard-coded.
    """
    # Do nothing when the widget is unset or has not been created yet.
    # NOTE(review): presumably this can fire while the UI is still being
    # built - confirm against the signal connections.
    root_widget = getattr(self, "root_widget", None)
    if root_widget is None:
        return
    root_widget.setText(f"{datasets_dir}/{text}")
|
||||||
|
|
||||||
|
def browse_root_directory(self):
    """Let the user pick the dataset root directory and show it in the root field.

    The dialog starts from the directory currently held by
    ``self.root_widget``. If the dialog returns an empty path, the field is
    left unchanged.
    """
    from PyQt5.QtWidgets import QFileDialog

    start_dir = self.root_widget.text()
    chosen_dir = QFileDialog.getExistingDirectory(
        self, "Select Root Directory", start_dir
    )
    if not chosen_dir:
        return
    self.root_widget.setText(chosen_dir)
|
||||||
|
|
||||||
|
def get_config_values(self):
    """Copy the current state of the UI controls into ``self.cfg``.

    Also refreshes ``self.repo_id_edited``, which is True when the repo id in
    the UI differs from ``self.original_repo_id``. ``policy`` is always set to
    None and ``control_type`` to "record", whatever the UI shows.
    """
    cfg = self.cfg

    # Dataset identity and location.
    cfg.fps = self.fps_widget.value()
    cfg.resume = self.resume_widget.isChecked()
    cfg.repo_id = self.repo_id_widget.text()
    self.repo_id_edited = cfg.repo_id != self.original_repo_id
    cfg.root = self.root_widget.text()

    # Episode settings.
    cfg.episode = self.episode_widget.value()
    cfg.num_episodes = self.num_episodes_widget.value()
    # NOTE(review): the spinbox is labelled "Episode Time (ms)" but feeds a
    # field whose name ends in ``_s`` - confirm the intended unit.
    cfg.episode_time_s = self.episode_time_widget.value()
    cfg.video = self.video_widget.isChecked()
    cfg.display_cameras = self.display_cameras_widget.isChecked()
    cfg.play_sounds = self.play_sounds_widget.isChecked()
    cfg.single_task = self.single_task_widget.text()

    # Advanced settings.
    cfg.num_image_writer_processes = self.num_writer_processes_widget.value()
    cfg.num_image_writer_threads_per_camera = self.num_writer_threads_widget.value()
    cfg.use_depth_image = self.use_depth_widget.isChecked()
    cfg.use_base = self.use_base_widget.isChecked()
    cfg.push_to_hub = self.push_to_hub_widget.isChecked()

    # Fixed values: this window only drives plain recording.
    cfg.policy = None
    cfg.control_type = "record"
|
||||||
|
|
||||||
|
def start_recording(self):
    """Run one recording session with the values currently shown in the UI.

    Loads a default configuration via ``main.get_arguments`` when ``self.cfg``
    is missing or None, copies the UI values into it, builds the robot from
    the selected config file and calls ``record``.

    The UI is switched to its "recording" state before ``record`` is called
    and switched back once the call returns or fails. Previously the state
    was only set after ``record`` had returned, which left the Record button
    enabled during the run and disabled afterwards.

    Any exception raised while creating the robot or recording is reported
    on stdout (message plus traceback) and is not propagated.
    """
    if getattr(self, "cfg", None) is None:
        from main import get_arguments

        self.cfg = get_arguments()

    if self.fps_widget is None:
        return

    self.get_config_values()

    # NOTE(review): record() appears to run the whole session synchronously
    # (main.py calls it the same way), so the state must change before it.
    self.is_recording = True
    self.record_button.setEnabled(False)
    self.stop_button.setEnabled(True)

    try:
        # Create robot instance with current config
        from common.rosrobot_factory import RobotFactory
        from common.utils.data_utils import record

        self.cfg.config_file = self.config_widget.text()
        self.robot = RobotFactory.create(
            config_file=self.cfg.config_file,
            args=self.cfg,
        )
        print("Recording started with configuration:", vars(self.cfg))
        record(self.robot, self.cfg)
    except Exception as e:
        import traceback

        print(f"Failed to start recording: {e}")
        # Keep the stack trace; the one-line message alone is hard to debug.
        traceback.print_exc()
    finally:
        # Return to idle whether recording finished, was stopped or failed.
        self.is_recording = False
        self.record_button.setEnabled(True)
        self.stop_button.setEnabled(False)
|
||||||
|
|
||||||
|
def stop_recording(self):
    """Return the UI to its idle state and send a synthetic ESC key press.

    NOTE(review): presumably the recording loop listens for ESC to end the
    session - confirm against ``record``.
    """
    self.is_recording = False
    self.record_button.setEnabled(True)
    self.stop_button.setEnabled(False)

    # Simulate an ESC key press: key down followed by key up.
    from pynput.keyboard import Key, Controller

    esc_sender = Controller()
    for send in (esc_sender.press, esc_sender.release):
        send(Key.esc)

    print("Recording stopped")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Start the Qt application, show the window and exit with the status
    # returned by the event loop.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    exit_status = qt_app.exec_()
    sys.exit(exit_status)
|
||||||
56
lerobot_aloha/main.py
Normal file
56
lerobot_aloha/main.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import argparse
|
||||||
|
from common.rosrobot_factory import RobotFactory
|
||||||
|
from common.utils.data_utils import record
|
||||||
|
from common.utils.replay_utils import replay
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def get_arguments(argv=None):
    """
    Parse command line arguments.

    Each setting is registered as an ``--<name>`` option whose default is the
    value this function previously hard-coded, so a call without options
    returns the same namespace as before. Boolean settings also accept the
    negated ``--no-<name>`` form.

    Args:
        argv: Argument list to parse. None (the default) parses
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed arguments
    """
    defaults = {
        "fps": 30,
        "resume": False,
        "repo_id": "move_the_bottle_from_the_right_to_the_scale_right",
        "root": "./data5",
        "episode": 0,  # replay episode
        "num_image_writer_processes": 0,
        "num_image_writer_threads_per_camera": 4,
        "video": True,
        "num_episodes": 100,
        "episode_time_s": 30000,
        "play_sounds": False,
        "display_cameras": True,
        "single_task": "move the bottle from the right to the scale right",
        "use_depth_image": True,
        "use_base": False,
        "push_to_hub": False,
        "policy": None,
        "control_type": "record",
    }

    parser = argparse.ArgumentParser()
    for name, default in defaults.items():
        flag = f"--{name}"
        # bool must be tested first: it is a subclass of int.
        if isinstance(default, bool):
            parser.add_argument(
                flag, action=argparse.BooleanOptionalAction, default=default
            )
        elif name == "control_type":
            parser.add_argument(flag, choices=("record", "replay"), default=default)
        else:
            value_type = str if default is None else type(default)
            parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
def control_robot(
    cfg,
    config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml",
):
    """
    Control robot based on configuration.

    Args:
        cfg: Configuration object; ``cfg.control_type`` selects the mode
            ("record" or "replay").
        config_file: Robot YAML configuration passed to the factory. Defaults
            to the path this function previously hard-coded.

    Raises:
        ValueError: If ``cfg.control_type`` is not a supported mode. This is
            checked before the robot is created, so a typo no longer builds a
            robot and then silently does nothing.
    """
    control_type = cfg.control_type
    if control_type not in ("record", "replay"):
        raise ValueError(
            f"Unsupported control_type {control_type!r}; expected 'record' or 'replay'"
        )

    # Create robot instance using factory pattern
    robot = RobotFactory.create(config_file=config_file, args=cfg)

    # Execute appropriate control mode
    if control_type == "record":
        record(robot, cfg)
    else:
        replay(robot, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Build the configuration, then run the selected control mode.
    control_robot(get_arguments())
|
||||||
78
lerobot_aloha/replay_data.py
Normal file
78
lerobot_aloha/replay_data.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from common.agilex_robot import AgilexRobot
|
||||||
|
import torch
|
||||||
|
import cv2
|
||||||
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.common.datasets.utils import cycle
|
||||||
|
|
||||||
|
|
||||||
|
def get_arguments(argv=None):
    """Build the replay configuration from the command line.

    Each setting is registered as an ``--<name>`` option whose default is the
    value this function previously hard-coded, so a call without options
    returns the same namespace as before. Boolean settings also accept the
    negated ``--no-<name>`` form.

    Args:
        argv: Argument list to parse. None (the default) parses
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    defaults = {
        "fps": 30,
        "resume": False,
        "repo_id": "tangger/test",
        "root": "/home/ubuntu/LYT/lerobot_aloha/datasets/abcde",
        "num_image_writer_processes": 0,
        "num_image_writer_threads_per_camera": 4,
        "video": True,
        "num_episodes": 50,
        "episode_time_s": 30000,
        "play_sounds": False,
        "display_cameras": True,
        "single_task": "test test",
        "use_depth_image": False,
        "use_base": False,
        "push_to_hub": False,
        "policy": None,
        # NOTE(review): looks like a misspelling of "teleoperate"; kept as-is
        # because consumers may read the attribute under this exact name.
        "teleoprate": False,
    }

    parser = argparse.ArgumentParser()
    for name, default in defaults.items():
        flag = f"--{name}"
        # bool must be tested first: it is a subclass of int.
        if isinstance(default, bool):
            parser.add_argument(
                flag, action=argparse.BooleanOptionalAction, default=default
            )
        else:
            value_type = str if default is None else type(default)
            parser.add_argument(flag, type=value_type, default=default)
    return parser.parse_args(argv)
|
||||||
|
|
||||||
|
|
||||||
|
# Replay script: stream the recorded actions of a dataset to the robot.
cfg = get_arguments()
robot = AgilexRobot(
    config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml",
    args=cfg,
)
# Rate (Hz) at which actions are sent to the robot.
# NOTE(review): cfg.fps is 30 while the replay runs at 15 - confirm that the
# slower rate is intentional.
fps = 15
device = "cuda"  # TODO: On Mac, use "mps" or "cpu"

dataset = LeRobotDataset(
    cfg.repo_id,
    root=cfg.root,
)
# Frames are read one at a time with shuffling off.
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=1,
    shuffle=False,
    sampler=None,
    pin_memory=device != "cpu",
    drop_last=False,
)
# NOTE(review): cycle() presumably restarts the loader once it is exhausted,
# so the replay would repeat until interrupted - confirm.
dl_iter = cycle(dataloader)

# Pace the replay at `fps` actions per second.
for data in dl_iter:
    start_time = time.perf_counter()

    # Remove the batch dimension and move to CPU, if not already the case.
    action = data["action"].squeeze(0).to("cpu")

    # Order the robot to move
    robot.send_action(action)
    print(action)

    # Wait out whatever is left of this frame's time budget.
    dt_s = time.perf_counter() - start_time
    busy_wait(1 / fps - dt_s)
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
from lerobot.common.policies.act.modeling_act import ACTPolicy
|
||||||
|
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
|
||||||
from lerobot.common.robot_devices.utils import busy_wait
|
from lerobot.common.robot_devices.utils import busy_wait
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
from agilex_robot import AgilexRobot
|
from common.agilex_robot import AgilexRobot
|
||||||
import torch
|
import torch
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
def get_arguments():
|
def get_arguments():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
@@ -29,13 +32,18 @@ def get_arguments():
|
|||||||
|
|
||||||
|
|
||||||
cfg = get_arguments()
|
cfg = get_arguments()
|
||||||
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
robot = AgilexRobot(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg)
|
||||||
inference_time_s = 360
|
inference_time_s = 360
|
||||||
fps = 30
|
fps = 30
|
||||||
device = "cuda" # TODO: On Mac, use "mps" or "cpu"
|
device = "cuda" # TODO: On Mac, use "mps" or "cpu"
|
||||||
|
|
||||||
ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"
|
# ckpt_path = "/home/ubuntu/LYT/lerobot_aloha/outputs/train/act_move_bottle_on_scale_without_front/checkpoints/last/pretrained_model"
|
||||||
policy = ACTPolicy.from_pretrained(ckpt_path)
|
ckpt_path ="/home/ubuntu/LYT/lerobot_aloha/outputs/train/act_abcde/checkpoints/last/pretrained_model"
|
||||||
|
policy = ACTPolicy.from_pretrained(pretrained_name_or_path=ckpt_path)
|
||||||
|
|
||||||
|
# ckpt_path ="/home/ubuntu/LYT/lerobot_aloha/outputs/train/diffusion_abcde/checkpoints/020000/pretrained_model"
|
||||||
|
# policy = DiffusionPolicy.from_pretrained(pretrained_name_or_path=ckpt_path)
|
||||||
|
|
||||||
policy.to(device)
|
policy.to(device)
|
||||||
|
|
||||||
for _ in range(inference_time_s * fps):
|
for _ in range(inference_time_s * fps):
|
||||||
@@ -46,16 +54,23 @@ for _ in range(inference_time_s * fps):
|
|||||||
if observation is None:
|
if observation is None:
|
||||||
print("Observation is None, skipping...")
|
print("Observation is None, skipping...")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# visualize the image in the obervation
|
||||||
|
# cv2.imshow("observation", observation["observation.image"])
|
||||||
|
|
||||||
# Convert to pytorch format: channel first and float32 in [0,1]
|
# Convert to pytorch format: channel first and float32 in [0,1]
|
||||||
# with batch dimension
|
# with batch dimension
|
||||||
for name in observation:
|
for name in observation:
|
||||||
if "image" in name:
|
if "image" in name:
|
||||||
|
img = observation[name].numpy()
|
||||||
|
# cv2.imshow(name, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
||||||
|
# cv2.imwrite(f"{name}.png", cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
|
||||||
observation[name] = observation[name].type(torch.float32) / 255
|
observation[name] = observation[name].type(torch.float32) / 255
|
||||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||||
observation[name] = observation[name].unsqueeze(0)
|
observation[name] = observation[name].unsqueeze(0)
|
||||||
observation[name] = observation[name].to(device)
|
observation[name] = observation[name].to(device)
|
||||||
|
|
||||||
|
last_pos = observation["observation.state"]
|
||||||
# Compute the next action with the policy
|
# Compute the next action with the policy
|
||||||
# based on the current observation
|
# based on the current observation
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
@@ -65,6 +80,8 @@ for _ in range(inference_time_s * fps):
|
|||||||
action = action.to("cpu")
|
action = action.to("cpu")
|
||||||
# Order the robot to move
|
# Order the robot to move
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
|
print("left pos:", action[:7])
|
||||||
|
print("right pos:", action[7:])
|
||||||
|
|
||||||
dt_s = time.perf_counter() - start_time
|
dt_s = time.perf_counter() - start_time
|
||||||
busy_wait(1 / fps - dt_s)
|
busy_wait(1 / fps - dt_s)
|
||||||
1
openpi
Submodule
1
openpi
Submodule
Submodule openpi added at 36dc3c037e
Reference in New Issue
Block a user