Use the native data collection code
@@ -1,3 +0,0 @@
python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data

python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1
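With the lerobot-style commands above removed, recording now goes through the native script added in this commit; the usage comment at the bottom of collect_data.py suggests an invocation such as:

python collect_data.py --dataset_dir ~/data --max_timesteps 500 --episode_idx 0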
@@ -1,146 +0,0 @@
robot_type: aloha_agilex
ros_node_name: record_episodes

cameras:
  cam_front:
    img_topic_name: /camera_f/color/image_raw
    depth_topic_name: /camera_f/depth/image_raw
    width: 640
    height: 480
    rgb_shape: [480, 640, 3]
  cam_left:
    img_topic_name: /camera_l/color/image_raw
    depth_topic_name: /camera_l/depth/image_raw
    width: 640
    height: 480
    rgb_shape: [480, 640, 3]
  cam_right:
    img_topic_name: /camera_r/color/image_raw
    depth_topic_name: /camera_r/depth/image_raw
    width: 640
    height: 480
    rgb_shape: [480, 640, 3]
  cam_high:
    img_topic_name: /camera/color/image_raw
    depth_topic_name: /camera/depth/image_rect_raw
    width: 640
    height: 480
    rgb_shape: [480, 640, 3]

arm:
  master_left:
    topic_name: /master/joint_left
    motors: [
      "left_joint0", "left_joint1", "left_joint2", "left_joint3",
      "left_joint4", "left_joint5", "left_none"
    ]
  master_right:
    topic_name: /master/joint_right
    motors: [
      "right_joint0", "right_joint1", "right_joint2", "right_joint3",
      "right_joint4", "right_joint5", "right_none"
    ]
  puppet_left:
    topic_name: /puppet/joint_left
    motors: [
      "left_joint0", "left_joint1", "left_joint2", "left_joint3",
      "left_joint4", "left_joint5", "left_none"
    ]
  puppet_right:
    topic_name: /puppet/joint_right
    motors: [
      "right_joint0", "right_joint1", "right_joint2", "right_joint3",
      "right_joint4", "right_joint5", "right_none"
    ]

# Joint names follow the ROS joint naming convention.
state:
  motors: [
    "left_joint0", "left_joint1", "left_joint2", "left_joint3",
    "left_joint4", "left_joint5", "left_none",
    "right_joint0", "right_joint1", "right_joint2", "right_joint3",
    "right_joint4", "right_joint5", "right_none"
  ]

velocity:
  motors: [
    "left_joint0", "left_joint1", "left_joint2", "left_joint3",
    "left_joint4", "left_joint5", "left_none",
    "right_joint0", "right_joint1", "right_joint2", "right_joint3",
    "right_joint4", "right_joint5", "right_none"
  ]

effort:
  motors: [
    "left_joint0", "left_joint1", "left_joint2", "left_joint3",
    "left_joint4", "left_joint5", "left_none",
    "right_joint0", "right_joint1", "right_joint2", "right_joint3",
    "right_joint4", "right_joint5", "right_none"
  ]

action:
  motors: [
    "left_joint0", "left_joint1", "left_joint2", "left_joint3",
    "left_joint4", "left_joint5", "left_none",
    "right_joint0", "right_joint1", "right_joint2", "right_joint3",
    "right_joint4", "right_joint5", "right_none"
  ]
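A minimal sketch of how a loader such as the RobotFactory.create(config_file=...) call in the removed control script might parse this file, assuming a standard PyYAML load; the load_robot_config helper is hypothetical and not part of this commit:

import yaml  # PyYAML

def load_robot_config(path="collect_data/agilex.yaml"):
    # Hypothetical helper: parse the config above into a plain dict.
    with open(path) as f:
        return yaml.safe_load(f)

cfg = load_robot_config()
front_rgb_topic = cfg["cameras"]["cam_front"]["img_topic_name"]  # "/camera_f/color/image_raw"
state_motor_names = cfg["state"]["motors"]  # 14 joint names, left arm then right arm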
collect_data/aloha_mobile.py (new file, 305 lines)
@@ -0,0 +1,305 @@
import cv2
import numpy as np
import dm_env

import collections
from collections import deque

import rospy
from sensor_msgs.msg import JointState
from sensor_msgs.msg import Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
from utils import display_camera_grid, save_data


class AlohaRobotRos:
    def __init__(self, args):
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.master_arm_right_deque = None
        self.master_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.args = args
        self.init()
        self.init_ros()

    def init(self):
        self.bridge = CvBridge()
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()
        self.master_arm_left_deque = deque()
        self.master_arm_right_deque = deque()
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()
        self.robot_base_deque = deque()

    def get_frame(self):
        print(len(self.img_left_deque), len(self.img_right_deque), len(self.img_front_deque),
              len(self.img_left_depth_deque), len(self.img_right_depth_deque), len(self.img_front_depth_deque))
        if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
                (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
            return False

        # Reference time: the smallest of the newest stamps across the image streams.
        if self.args.use_depth_image:
            frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
                              self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
        else:
            frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])

        # Every stream must have at least one message at or after the reference time.
        if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.master_arm_left_deque) == 0 or self.master_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.master_arm_right_deque) == 0 or self.master_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
            return False

        # Drop stale messages, then take the first message at or after frame_time.
        while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
            self.img_left_deque.popleft()
        img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')

        while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
            self.img_right_deque.popleft()
        img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')

        while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
            self.img_front_deque.popleft()
        img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')

        while self.master_arm_left_deque[0].header.stamp.to_sec() < frame_time:
            self.master_arm_left_deque.popleft()
        master_arm_left = self.master_arm_left_deque.popleft()

        while self.master_arm_right_deque[0].header.stamp.to_sec() < frame_time:
            self.master_arm_right_deque.popleft()
        master_arm_right = self.master_arm_right_deque.popleft()

        while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_left_deque.popleft()
        puppet_arm_left = self.puppet_arm_left_deque.popleft()

        while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_right_deque.popleft()
        puppet_arm_right = self.puppet_arm_right_deque.popleft()

        img_left_depth = None
        if self.args.use_depth_image:
            while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_left_depth_deque.popleft()
            img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
            # Pad 40 rows top and bottom, presumably to match the color image height.
            top, bottom, left, right = 40, 40, 0, 0
            img_left_depth = cv2.copyMakeBorder(img_left_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

        img_right_depth = None
        if self.args.use_depth_image:
            while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_right_depth_deque.popleft()
            img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
            top, bottom, left, right = 40, 40, 0, 0
            img_right_depth = cv2.copyMakeBorder(img_right_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

        img_front_depth = None
        if self.args.use_depth_image:
            while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_front_depth_deque.popleft()
            img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
            top, bottom, left, right = 40, 40, 0, 0
            img_front_depth = cv2.copyMakeBorder(img_front_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

        robot_base = None
        if self.args.use_robot_base:
            while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
                self.robot_base_deque.popleft()
            robot_base = self.robot_base_deque.popleft()

        return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
                puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base)

    # Each callback keeps at most 2000 messages per topic, dropping the oldest first.
    def img_left_callback(self, msg):
        if len(self.img_left_deque) >= 2000:
            self.img_left_deque.popleft()
        self.img_left_deque.append(msg)

    def img_right_callback(self, msg):
        if len(self.img_right_deque) >= 2000:
            self.img_right_deque.popleft()
        self.img_right_deque.append(msg)

    def img_front_callback(self, msg):
        if len(self.img_front_deque) >= 2000:
            self.img_front_deque.popleft()
        self.img_front_deque.append(msg)

    def img_left_depth_callback(self, msg):
        if len(self.img_left_depth_deque) >= 2000:
            self.img_left_depth_deque.popleft()
        self.img_left_depth_deque.append(msg)

    def img_right_depth_callback(self, msg):
        if len(self.img_right_depth_deque) >= 2000:
            self.img_right_depth_deque.popleft()
        self.img_right_depth_deque.append(msg)

    def img_front_depth_callback(self, msg):
        if len(self.img_front_depth_deque) >= 2000:
            self.img_front_depth_deque.popleft()
        self.img_front_depth_deque.append(msg)

    def master_arm_left_callback(self, msg):
        if len(self.master_arm_left_deque) >= 2000:
            self.master_arm_left_deque.popleft()
        self.master_arm_left_deque.append(msg)

    def master_arm_right_callback(self, msg):
        if len(self.master_arm_right_deque) >= 2000:
            self.master_arm_right_deque.popleft()
        self.master_arm_right_deque.append(msg)

    def puppet_arm_left_callback(self, msg):
        if len(self.puppet_arm_left_deque) >= 2000:
            self.puppet_arm_left_deque.popleft()
        self.puppet_arm_left_deque.append(msg)

    def puppet_arm_right_callback(self, msg):
        if len(self.puppet_arm_right_deque) >= 2000:
            self.puppet_arm_right_deque.popleft()
        self.puppet_arm_right_deque.append(msg)

    def robot_base_callback(self, msg):
        if len(self.robot_base_deque) >= 2000:
            self.robot_base_deque.popleft()
        self.robot_base_deque.append(msg)

    def init_ros(self):
        rospy.init_node('record_episodes', anonymous=True)
        rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
        if self.args.use_depth_image:
            rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)

        rospy.Subscriber(self.args.master_arm_left_topic, JointState, self.master_arm_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.master_arm_right_topic, JointState, self.master_arm_right_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
        # rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)

    def process(self):
        timesteps = []
        actions = []
        # Placeholder image data (replaced once the first synchronized frame arrives)
        image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)
        image_dict = dict()
        for cam_name in self.args.camera_names:
            image_dict[cam_name] = image
        count = 0

        rate = rospy.Rate(self.args.frame_rate)
        print_flag = True

        while (count < self.args.max_timesteps + 1) and not rospy.is_shutdown():
            # 2. Collect one synchronized frame
            result = self.get_frame()
            if not result:
                if print_flag:
                    print("sync fail\n")
                    print_flag = False
                rate.sleep()
                continue
            print_flag = True
            count += 1
            (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
             puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base) = result

            # 2.1 Image data
            image_dict = dict()
            image_dict[self.args.camera_names[0]] = img_front
            image_dict[self.args.camera_names[1]] = img_left
            image_dict[self.args.camera_names[2]] = img_right

            display_camera_grid(image_dict)

            # 2.2 Puppet-arm state; in teach (teleoperation) mode these topics are published automatically.
            obs = collections.OrderedDict()
            obs['images'] = image_dict
            if self.args.use_depth_image:
                image_dict_depth = dict()
                image_dict_depth[self.args.camera_names[0]] = img_front_depth
                image_dict_depth[self.args.camera_names[1]] = img_left_depth
                image_dict_depth[self.args.camera_names[2]] = img_right_depth
                obs['images_depth'] = image_dict_depth
            obs['qpos'] = np.concatenate((np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
            obs['qvel'] = np.concatenate((np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
            obs['effort'] = np.concatenate((np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
            if self.args.use_robot_base:
                obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
            else:
                obs['base_vel'] = [0.0, 0.0]

            # The first frame only records StepType.FIRST; no action is stored for it.
            if count == 1:
                ts = dm_env.TimeStep(
                    step_type=dm_env.StepType.FIRST,
                    reward=None,
                    discount=None,
                    observation=obs)
                timesteps.append(ts)
                continue

            # Intermediate time step
            ts = dm_env.TimeStep(
                step_type=dm_env.StepType.MID,
                reward=None,
                discount=None,
                observation=obs)

            # The master-arm joint positions are stored as the action.
            action = np.concatenate((np.array(master_arm_left.position), np.array(master_arm_right.position)), axis=0)
            actions.append(action)
            timesteps.append(ts)
            print(f"\n{self.args.episode_idx} | Frame data: {count}\r", end="")
            if rospy.is_shutdown():
                exit(-1)
            rate.sleep()

        print("len(timesteps): ", len(timesteps))
        print("len(actions) : ", len(actions))
        return timesteps, actions
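The heart of aloha_mobile.py is the timestamp alignment in get_frame: take the smallest of the newest stamps across all streams as a reference time, give up on the frame if any stream has not yet reached it, then drop stale messages from the head of each deque. A minimal sketch of that idea, with plain (stamp, payload) tuples standing in for ROS messages:

from collections import deque

def take_synced(streams):
    # streams: dict name -> deque of (stamp, payload), oldest first.
    # Simplified sketch of AlohaRobotRos.get_frame's alignment logic.
    if any(len(d) == 0 for d in streams.values()):
        return None
    # Reference time: the smallest of the newest stamps across streams.
    ref = min(d[-1][0] for d in streams.values())
    out = {}
    for name, d in streams.items():
        while d[0][0] < ref:      # drop messages older than the reference
            d.popleft()
        out[name] = d.popleft()   # take the first message at or after it
    return out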
collect_data/collect_data.py (new executable file, 166 lines)
@@ -0,0 +1,166 @@
import os
import time
import argparse
from aloha_mobile import AlohaRobotRos
from utils import save_data, init_keyboard_listener


def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset directory.',
                        default="./data", required=False)
    parser.add_argument('--task_name', action='store', type=str, help='Task name.',
                        default="aloha_mobile_dummy", required=False)
    parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.',
                        default=0, required=False)
    parser.add_argument('--max_timesteps', action='store', type=int, help='Max timesteps.',
                        default=500, required=False)
    parser.add_argument('--camera_names', action='store', type=str, nargs='+', help='Camera names.',
                        default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False)
    parser.add_argument('--num_episodes', action='store', type=int, help='Number of episodes.',
                        default=1, required=False)

    # Topic names of the color images
    parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
                        default='/camera_f/color/image_raw', required=False)
    parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
                        default='/camera_l/color/image_raw', required=False)
    parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
                        default='/camera_r/color/image_raw', required=False)

    # Topic names of the depth images
    parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
                        default='/camera_f/depth/image_raw', required=False)
    parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
                        default='/camera_l/depth/image_raw', required=False)
    parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
                        default='/camera_r/depth/image_raw', required=False)

    # Topic names of the arms
    parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
                        default='/master/joint_left', required=False)
    parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
                        default='/master/joint_right', required=False)
    parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
                        default='/puppet/joint_left', required=False)
    parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
                        default='/puppet/joint_right', required=False)

    # Topic name of the robot base
    parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
                        default='/odom', required=False)

    # argparse cannot parse `type=bool` reliably (the string "False" is truthy),
    # so these two options are plain flags.
    parser.add_argument('--use_robot_base', action='store_true', help='Use the mobile base.')

    # Collect depth images
    parser.add_argument('--use_depth_image', action='store_true', help='Collect depth images.')

    parser.add_argument('--frame_rate', action='store', type=int, help='Frame rate.',
                        default=30, required=False)

    args = parser.parse_args()
    return args


def main():
    args = get_arguments()
    ros_operator = AlohaRobotRos(args)
    dataset_dir = os.path.join(args.dataset_dir, args.task_name)
    # Make sure the dataset directory exists
    os.makedirs(dataset_dir, exist_ok=True)

    # Single-episode mode
    if args.num_episodes == 1:
        print(f"Recording single episode {args.episode_idx}...")
        timesteps, actions = ros_operator.process()

        if len(actions) < args.max_timesteps:
            print(f"\033[31m\nSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m\n")
            return -1

        dataset_path = os.path.join(dataset_dir, f"episode_{args.episode_idx}")
        save_data(args, timesteps, actions, dataset_path)
        print(f"\033[32mEpisode {args.episode_idx} saved successfully at {dataset_path}\033[0m")
        return 0

    # Multi-episode mode
    print("""
    \033[1;36mKeyboard Controls:\033[0m
    ← \033[1mLeft Arrow\033[0m: Start Recording
    → \033[1mRight Arrow\033[0m: Save Current Data
    ↓ \033[1mDown Arrow\033[0m: Discard Current Data
    ↑ \033[1mUp Arrow\033[0m: Replay Data (if implemented)
    \033[1mESC\033[0m: Exit Program
    """)
    # Initialize the keyboard listener
    listener, events = init_keyboard_listener()
    episode_idx = args.episode_idx
    collected_episodes = 0

    try:
        while collected_episodes < args.num_episodes:
            if events["exit_early"]:
                print("\033[33mOperation terminated by user\033[0m")
                break

            if events["record_start"]:
                # Reset the event state and start a new recording
                events["record_start"] = False
                events["save_data"] = False
                events["discard_data"] = False

                print(f"\n\033[1;32mRecording episode {episode_idx}...\033[0m")
                timesteps, actions = ros_operator.process()
                print(f"\033[1;33mRecorded {len(actions)} timesteps. (→ to save, ↓ to discard)\033[0m")

                # Wait for the user to save or discard
                while True:
                    if events["save_data"]:
                        events["save_data"] = False

                        if len(actions) < args.max_timesteps:
                            print(f"\033[31mSave failure: Recorded only {len(actions)}/{args.max_timesteps} timesteps.\033[0m")
                        else:
                            dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
                            save_data(args, timesteps, actions, dataset_path)
                            print(f"\033[32mEpisode {episode_idx} saved successfully at {dataset_path}\033[0m")
                            episode_idx += 1
                            collected_episodes += 1
                            print(f"\033[1mProgress: {collected_episodes}/{args.num_episodes} episodes collected. (← to start new episode)\033[0m")
                        break

                    if events["discard_data"]:
                        events["discard_data"] = False
                        print("\033[33mData discarded. Press ← to start a new recording.\033[0m")
                        break

                    if events["exit_early"]:
                        print("\033[33mOperation terminated by user\033[0m")
                        return 0

                    time.sleep(0.1)  # reduce CPU usage

            time.sleep(0.1)  # reduce CPU usage

        if collected_episodes == args.num_episodes:
            print(f"\n\033[1;32mData collection complete! All {args.num_episodes} episodes collected.\033[0m")

    finally:
        # Make sure the listener is cleaned up
        if listener:
            listener.stop()
            print("Keyboard listener stopped")


if __name__ == '__main__':
    try:
        exit_code = main()
        exit(exit_code if exit_code is not None else 0)
    except KeyboardInterrupt:
        print("\n\033[33mProgram interrupted by user\033[0m")
        exit(0)
    except Exception as e:
        print(f"\n\033[31mError: {e}\033[0m")
        exit(1)

# python collect_data.py --dataset_dir ~/data --max_timesteps 500 --episode_idx 0
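collect_data.py imports init_keyboard_listener from a utils module that is not part of this diff. Based on the event keys consumed above (record_start, save_data, discard_data, exit_early) and the printed help text, a pynput-based implementation might look like the following sketch; the key bindings and the helper itself are assumptions, not the actual utils code:

from pynput import keyboard

def init_keyboard_listener():
    # Hypothetical sketch of the utils helper used by collect_data.py.
    events = {"record_start": False, "save_data": False,
              "discard_data": False, "exit_early": False}

    def on_press(key):
        if key == keyboard.Key.left:      # start recording
            events["record_start"] = True
        elif key == keyboard.Key.right:   # save current data
            events["save_data"] = True
        elif key == keyboard.Key.down:    # discard current data
            events["discard_data"] = True
        elif key == keyboard.Key.esc:     # exit program
            events["exit_early"] = True

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events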
collect_data/collect_data_gui.py (new file, 416 lines)
@@ -0,0 +1,416 @@
import os
import sys
import time
import argparse
from aloha_mobile import AlohaRobotRos
from utils import save_data, init_keyboard_listener
from PyQt5.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
                             QLabel, QLineEdit, QPushButton, QCheckBox, QSpinBox,
                             QGroupBox, QFormLayout, QTabWidget, QTextEdit, QFileDialog,
                             QMessageBox, QProgressBar, QComboBox)
from PyQt5.QtCore import Qt, QThread, pyqtSignal, pyqtSlot
from PyQt5.QtGui import QFont, QIcon, QTextCursor


class DataCollectionThread(QThread):
    """Worker thread that runs the data collection."""
    update_signal = pyqtSignal(str)
    progress_signal = pyqtSignal(int)
    finish_signal = pyqtSignal(bool, str)

    def __init__(self, args, parent=None):
        super(DataCollectionThread, self).__init__(parent)
        self.args = args
        self.is_running = True
        self.ros_operator = None

    def run(self):
        try:
            self.update_signal.emit("Initializing ROS operator...\n")
            self.ros_operator = AlohaRobotRos(self.args)
            dataset_dir = os.path.join(self.args.dataset_dir, self.args.task_name)
            os.makedirs(dataset_dir, exist_ok=True)

            # Single-episode mode
            if self.args.num_episodes == 1:
                self.update_signal.emit(f"Recording episode {self.args.episode_idx}...\n")
                timesteps, actions = self.ros_operator.process()

                if len(actions) < self.args.max_timesteps:
                    self.update_signal.emit(f"Save failure: recorded only {len(actions)}/{self.args.max_timesteps} timesteps.\n")
                    self.finish_signal.emit(False, f"Recorded only {len(actions)}/{self.args.max_timesteps} timesteps")
                    return

                dataset_path = os.path.join(dataset_dir, f"episode_{self.args.episode_idx}")
                save_data(self.args, timesteps, actions, dataset_path)
                self.update_signal.emit(f"Episode {self.args.episode_idx} saved successfully at {dataset_path}.\n")
                self.finish_signal.emit(True, "Data collection complete")

            # Multi-episode mode
            else:
                self.update_signal.emit("""
Keyboard controls:
  ← Left arrow: start recording
  → Right arrow: save current data
  ↓ Down arrow: discard current data
  ESC: exit program
""")
                # Initialize the keyboard listener
                listener, events = init_keyboard_listener()
                episode_idx = self.args.episode_idx
                collected_episodes = 0

                try:
                    while collected_episodes < self.args.num_episodes and self.is_running:
                        if events["exit_early"]:
                            self.update_signal.emit("Operation terminated by user.\n")
                            break

                        if events["record_start"]:
                            # Reset the event state and start a new recording
                            events["record_start"] = False
                            events["save_data"] = False
                            events["discard_data"] = False

                            self.update_signal.emit(f"\nRecording episode {episode_idx}...\n")
                            timesteps, actions = self.ros_operator.process()
                            self.update_signal.emit(f"Recorded {len(actions)} timesteps. (→ to save, ↓ to discard)\n")

                            # Wait for the user to save or discard
                            while self.is_running:
                                if events["save_data"]:
                                    events["save_data"] = False

                                    if len(actions) < self.args.max_timesteps:
                                        self.update_signal.emit(f"Save failure: recorded only {len(actions)}/{self.args.max_timesteps} timesteps.\n")
                                    else:
                                        dataset_path = os.path.join(dataset_dir, f"episode_{episode_idx}")
                                        save_data(self.args, timesteps, actions, dataset_path)
                                        self.update_signal.emit(f"Episode {episode_idx} saved successfully at {dataset_path}.\n")
                                        episode_idx += 1
                                        collected_episodes += 1
                                        progress_percentage = int(collected_episodes * 100 / self.args.num_episodes)
                                        self.progress_signal.emit(progress_percentage)
                                        self.update_signal.emit(f"Progress: {collected_episodes}/{self.args.num_episodes} episodes collected. (← to start a new episode)\n")
                                    break

                                if events["discard_data"]:
                                    events["discard_data"] = False
                                    self.update_signal.emit("Data discarded. Press ← to start a new recording.\n")
                                    break

                                if events["exit_early"]:
                                    self.update_signal.emit("Operation terminated by user.\n")
                                    self.is_running = False
                                    break

                                time.sleep(0.1)  # reduce CPU usage

                        time.sleep(0.1)  # reduce CPU usage

                    if collected_episodes == self.args.num_episodes:
                        self.update_signal.emit(f"\nData collection complete! All {self.args.num_episodes} episodes collected.\n")
                        self.finish_signal.emit(True, "All episodes collected")
                    else:
                        self.finish_signal.emit(False, "Data collection incomplete")

                finally:
                    # Make sure the listener is cleaned up
                    if listener:
                        listener.stop()
                        self.update_signal.emit("Keyboard listener stopped\n")

        except Exception as e:
            self.update_signal.emit(f"Error: {str(e)}\n")
            self.finish_signal.emit(False, str(e))

    def stop(self):
        self.is_running = False
        self.wait()


class AlohaDataCollectionGUI(QMainWindow):
    def __init__(self):
        super().__init__()
        self.setWindowTitle("ALOHA Data Collection Tool")
        self.setGeometry(100, 100, 800, 700)

        # Main widget
        self.central_widget = QWidget()
        self.setCentralWidget(self.central_widget)
        self.main_layout = QVBoxLayout(self.central_widget)

        # Tabs
        self.tab_widget = QTabWidget()
        self.main_layout.addWidget(self.tab_widget)

        # Configuration tab
        self.config_tab = QWidget()
        self.tab_widget.addTab(self.config_tab, "Configuration")

        # Data-collection tab
        self.collection_tab = QWidget()
        self.tab_widget.addTab(self.collection_tab, "Data Collection")

        self.setup_config_tab()
        self.setup_collection_tab()

        # Data-collection worker thread
        self.collection_thread = None

    def setup_config_tab(self):
        config_layout = QVBoxLayout(self.config_tab)

        # Basic configuration group
        basic_group = QGroupBox("Basic Configuration")
        basic_form = QFormLayout()

        self.dataset_dir = QLineEdit("./data")
        self.browse_button = QPushButton("Browse")
        self.browse_button.clicked.connect(self.browse_dataset_dir)

        dir_layout = QHBoxLayout()
        dir_layout.addWidget(self.dataset_dir)
        dir_layout.addWidget(self.browse_button)

        self.task_name = QLineEdit("aloha_mobile_dummy")
        self.episode_idx = QSpinBox()
        self.episode_idx.setRange(0, 1000)
        self.max_timesteps = QSpinBox()
        self.max_timesteps.setRange(1, 10000)
        self.max_timesteps.setValue(500)
        self.num_episodes = QSpinBox()
        self.num_episodes.setRange(1, 100)
        self.num_episodes.setValue(1)
        self.frame_rate = QSpinBox()
        self.frame_rate.setRange(1, 120)
        self.frame_rate.setValue(30)

        basic_form.addRow("Dataset directory:", dir_layout)
        basic_form.addRow("Task name:", self.task_name)
        basic_form.addRow("Start episode index:", self.episode_idx)
        basic_form.addRow("Max timesteps:", self.max_timesteps)
        basic_form.addRow("Number of episodes:", self.num_episodes)
        basic_form.addRow("Frame rate:", self.frame_rate)

        basic_group.setLayout(basic_form)
        config_layout.addWidget(basic_group)

        # Camera topics group
        camera_group = QGroupBox("Camera Topics")
        camera_form = QFormLayout()

        self.img_front_topic = QLineEdit('/camera_f/color/image_raw')
        self.img_left_topic = QLineEdit('/camera_l/color/image_raw')
        self.img_right_topic = QLineEdit('/camera_r/color/image_raw')

        self.img_front_depth_topic = QLineEdit('/camera_f/depth/image_raw')
        self.img_left_depth_topic = QLineEdit('/camera_l/depth/image_raw')
        self.img_right_depth_topic = QLineEdit('/camera_r/depth/image_raw')

        self.use_depth_image = QCheckBox("Use depth images")

        camera_form.addRow("Front camera:", self.img_front_topic)
        camera_form.addRow("Left wrist camera:", self.img_left_topic)
        camera_form.addRow("Right wrist camera:", self.img_right_topic)
        camera_form.addRow("Front depth:", self.img_front_depth_topic)
        camera_form.addRow("Left wrist depth:", self.img_left_depth_topic)
        camera_form.addRow("Right wrist depth:", self.img_right_depth_topic)
        camera_form.addRow("", self.use_depth_image)

        camera_group.setLayout(camera_form)
        config_layout.addWidget(camera_group)

        # Robot topics group
        robot_group = QGroupBox("Robot Topics")
        robot_form = QFormLayout()

        self.master_arm_left_topic = QLineEdit('/master/joint_left')
        self.master_arm_right_topic = QLineEdit('/master/joint_right')
        self.puppet_arm_left_topic = QLineEdit('/puppet/joint_left')
        self.puppet_arm_right_topic = QLineEdit('/puppet/joint_right')
        self.robot_base_topic = QLineEdit('/odom')
        self.use_robot_base = QCheckBox("Use mobile base")

        robot_form.addRow("Master left arm:", self.master_arm_left_topic)
        robot_form.addRow("Master right arm:", self.master_arm_right_topic)
        robot_form.addRow("Puppet left arm:", self.puppet_arm_left_topic)
        robot_form.addRow("Puppet right arm:", self.puppet_arm_right_topic)
        robot_form.addRow("Base:", self.robot_base_topic)
        robot_form.addRow("", self.use_robot_base)

        robot_group.setLayout(robot_form)
        config_layout.addWidget(robot_group)

        # Camera names
        camera_names_group = QGroupBox("Camera Names")
        camera_names_layout = QVBoxLayout()

        self.camera_names = ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
        self.camera_checkboxes = {}

        for cam_name in self.camera_names:
            self.camera_checkboxes[cam_name] = QCheckBox(cam_name)
            self.camera_checkboxes[cam_name].setChecked(True)
            camera_names_layout.addWidget(self.camera_checkboxes[cam_name])

        camera_names_group.setLayout(camera_names_layout)
        config_layout.addWidget(camera_names_group)

        # Save-configuration button
        self.save_config_button = QPushButton("Save Configuration")
        self.save_config_button.clicked.connect(self.save_config)
        config_layout.addWidget(self.save_config_button)

    def setup_collection_tab(self):
        collection_layout = QVBoxLayout(self.collection_tab)

        # Current configuration display
        config_group = QGroupBox("Current Configuration")
        self.config_text = QTextEdit()
        self.config_text.setReadOnly(True)
        config_layout = QVBoxLayout()
        config_layout.addWidget(self.config_text)
        config_group.setLayout(config_layout)
        collection_layout.addWidget(config_group)

        # Action buttons
        buttons_layout = QHBoxLayout()

        self.start_button = QPushButton("Start Collection")
        self.start_button.setIcon(QIcon.fromTheme("media-playback-start"))
        self.start_button.clicked.connect(self.start_collection)

        self.stop_button = QPushButton("Stop")
        self.stop_button.setIcon(QIcon.fromTheme("media-playback-stop"))
        self.stop_button.setEnabled(False)
        self.stop_button.clicked.connect(self.stop_collection)

        buttons_layout.addWidget(self.start_button)
        buttons_layout.addWidget(self.stop_button)
        collection_layout.addLayout(buttons_layout)

        # Progress bar
        self.progress_bar = QProgressBar()
        self.progress_bar.setValue(0)
        collection_layout.addWidget(self.progress_bar)

        # Log output
        log_group = QGroupBox("Operation Log")
        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        log_layout = QVBoxLayout()
        log_layout.addWidget(self.log_text)
        log_group.setLayout(log_layout)
        collection_layout.addWidget(log_group)

    def browse_dataset_dir(self):
        directory = QFileDialog.getExistingDirectory(self, "Select dataset directory", self.dataset_dir.text())
        if directory:
            self.dataset_dir.setText(directory)

    def save_config(self):
        # Refresh the configuration display
        selected_cameras = [cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()]

        config_text = f"""
Task name: {self.task_name.text()}
Dataset directory: {self.dataset_dir.text()}
Start episode index: {self.episode_idx.value()}
Max timesteps: {self.max_timesteps.value()}
Number of episodes: {self.num_episodes.value()}
Frame rate: {self.frame_rate.value()}
Use depth images: {"yes" if self.use_depth_image.isChecked() else "no"}
Use mobile base: {"yes" if self.use_robot_base.isChecked() else "no"}
Cameras: {', '.join(selected_cameras)}
"""

        self.config_text.setText(config_text)
        self.tab_widget.setCurrentIndex(1)  # switch to the collection tab

        QMessageBox.information(self, "Configuration Saved", "Configuration updated; data collection can start")

    def start_collection(self):
        if not self.task_name.text():
            QMessageBox.warning(self, "Configuration Error", "Please enter a valid task name")
            return

        # Build the arguments
        args = argparse.Namespace(
            dataset_dir=self.dataset_dir.text(),
            task_name=self.task_name.text(),
            episode_idx=self.episode_idx.value(),
            max_timesteps=self.max_timesteps.value(),
            num_episodes=self.num_episodes.value(),
            camera_names=[cam for cam, checkbox in self.camera_checkboxes.items() if checkbox.isChecked()],
            img_front_topic=self.img_front_topic.text(),
            img_left_topic=self.img_left_topic.text(),
            img_right_topic=self.img_right_topic.text(),
            img_front_depth_topic=self.img_front_depth_topic.text(),
            img_left_depth_topic=self.img_left_depth_topic.text(),
            img_right_depth_topic=self.img_right_depth_topic.text(),
            master_arm_left_topic=self.master_arm_left_topic.text(),
            master_arm_right_topic=self.master_arm_right_topic.text(),
            puppet_arm_left_topic=self.puppet_arm_left_topic.text(),
            puppet_arm_right_topic=self.puppet_arm_right_topic.text(),
            robot_base_topic=self.robot_base_topic.text(),
            use_robot_base=self.use_robot_base.isChecked(),
            use_depth_image=self.use_depth_image.isChecked(),
            frame_rate=self.frame_rate.value()
        )

        # Update the UI state
        self.start_button.setEnabled(False)
        self.stop_button.setEnabled(True)
        self.progress_bar.setValue(0)
        self.log_text.clear()
        self.log_text.append("Initializing data collection...\n")

        # Create and start the worker thread
        self.collection_thread = DataCollectionThread(args)
        self.collection_thread.update_signal.connect(self.update_log)
        self.collection_thread.progress_signal.connect(self.update_progress)
        self.collection_thread.finish_signal.connect(self.collection_finished)
        self.collection_thread.start()

    def stop_collection(self):
        if self.collection_thread and self.collection_thread.isRunning():
            self.log_text.append("Stopping data collection...\n")
            self.collection_thread.stop()
            self.collection_thread = None

        self.start_button.setEnabled(True)
        self.stop_button.setEnabled(False)

    @pyqtSlot(str)
    def update_log(self, message):
        self.log_text.append(message)
        # Auto-scroll to the bottom
        cursor = self.log_text.textCursor()
        cursor.movePosition(QTextCursor.End)
        self.log_text.setTextCursor(cursor)

    @pyqtSlot(int)
    def update_progress(self, value):
        self.progress_bar.setValue(value)

    @pyqtSlot(bool, str)
    def collection_finished(self, success, message):
        self.start_button.setEnabled(True)
        self.stop_button.setEnabled(False)

        if success:
            QMessageBox.information(self, "Done", message)
        else:
            QMessageBox.warning(self, "Error", f"Data collection failed: {message}")

        # Advance episode_idx for the next run
        if success and self.num_episodes.value() > 0:
            self.episode_idx.setValue(self.episode_idx.value() + self.num_episodes.value())


def main():
    app = QApplication(sys.argv)
    window = AlohaDataCollectionGUI()
    window.show()
    sys.exit(app.exec_())


if __name__ == "__main__":
    main()
@@ -1,462 +0,0 @@
import logging
import time
from dataclasses import asdict
from pprint import pprint

# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
    CalibrateControlConfig,
    ControlPipelineConfig,
    RecordControlConfig,
    RemoteRobotConfig,
    ReplayControlConfig,
    TeleoperateControlConfig,
)
# record_episode, stop_recording and init_keyboard_listener are redefined below,
# so only is_headless is imported here.
from lerobot.common.robot_devices.control_utils import is_headless
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.common.utils.utils import get_safe_torch_device
from contextlib import nullcontext
from copy import copy
import torch
import rospy
import cv2
from lerobot.configs import parser
from agilex_robot import AgilexRobot


########################################################################################
# Control modes
########################################################################################


def predict_action(observation, policy, device, use_amp):
    observation = copy(observation)
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

        # Compute the next action with the policy
        # based on the current observation
        action = policy.select_action(observation)

        # Remove batch dimension
        action = action.squeeze(0)

        # Move to cpu, if not already the case
        action = action.to("cpu")

    return action


def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    fps: int | None = None,
    single_task: str | None = None,
):
    # TODO(rcadene): Add option to record logs
    # if not robot.is_connected:
    #     robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    rate = rospy.Rate(fps)
    print_flag = True
    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                if print_flag:
                    print("sync data fail, retrying...\n")
                    print_flag = False
                rate.sleep()
                continue
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so the action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]

            # Screen resolution used for tiling the camera windows
            # (assumed to be 1920x1080; adjust to the actual display).
            screen_width = 1920
            screen_height = 1080

            # Work out the window grid layout
            num_images = len(image_keys)
            max_columns = int(screen_width / 640)  # assume each window is 640 px wide
            rows = (num_images + max_columns - 1) // max_columns  # rows needed
            columns = min(num_images, max_columns)  # columns actually used

            # Show every (non-depth) image
            for idx, key in enumerate(image_keys):
                if "depth" in key:
                    continue  # skip depth images

                # Convert the image from RGB to BGR
                image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)

                # Create the window
                cv2.imshow(key, image)

                # Work out the window position
                window_width = 640
                window_height = 480
                row = idx // max_columns
                col = idx % max_columns
                x_position = col * window_width
                y_position = row * window_height

                # Move the window into place
                cv2.moveWindow(key, x_position, y_position)

            # Wait 1 ms to process GUI events
            cv2.waitKey(1)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        # log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break


def init_keyboard_listener():
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {}
    events["exit_early"] = False
    events["record_start"] = False
    events["rerecord_episode"] = False
    events["stop_recording"] = False

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        listener = None
        return listener, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
                events["record_start"] = False
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and re-recording the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.up:
                print("Up arrow pressed. Starting data recording...")
                events["record_start"] = True

        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events


def stop_recording(robot, listener, display_cameras):
    if not is_headless():
        if listener is not None:
            listener.stop()

        if display_cameras:
            cv2.destroyAllWindows()


def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        teleoperate=policy is None,
        single_task=single_task,
    )


def record(
    robot,
    cfg
) -> LeRobotDataset:
    # TODO(rcadene): Add option to record logs
    if cfg.resume:
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        if len(robot.cameras) > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
            )
        # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
    else:
        # Create an empty dataset or load existing saved episodes
        # sanity_check_dataset_name(cfg.repo_id, cfg.policy)
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
        )

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    # if not robot.is_connected:
    #     robot.connect()

    listener, events = init_keyboard_listener()

    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
    # 2. give times to the robot devices to connect and start synchronizing,
    # 3. place the cameras windows on screen
    enable_teleoperation = policy is None
    log_say("Warmup record", cfg.play_sounds)
    print()
    print(f"Start recording trajectories: {cfg.num_episodes} in total\n"
          f"Each trajectory runs for at most {cfg.episode_time_s} seconds\n"
          "Up arrow: start recording the current trajectory\n"
          "Right arrow: finish recording the current trajectory\n"
          "Left arrow: re-record the current trajectory\n"
          "ESC: quit recording\n")
    # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)

    # if has_method(robot, "teleop_safety_stop"):
    #     robot.teleop_safety_stop()

    recorded_episodes = 0
    while True:
        if recorded_episodes >= cfg.num_episodes:
            break

        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Execute a few seconds without recording to give time to manually reset the environment.
        # Current code logic doesn't allow to teleoperate during this time.
        # TODO(rcadene): add an option to enable teleoperation during reset
        # Skip reset for the last episode to be recorded
        if not events["stop_recording"] and (
            (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")
            # reset_environment(robot, events, cfg.reset_time_s, cfg.fps)

        if events["rerecord_episode"]:
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset


def replay(
    robot: AgilexRobot,
    cfg,
):
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")

    # if not robot.is_connected:
    #     robot.connect()

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    for idx in range(dataset.num_frames):
        start_episode_t = time.perf_counter()

        action = actions[idx]["action"]
        robot.send_action(action)

        dt_s = time.perf_counter() - start_episode_t
        busy_wait(1 / cfg.fps - dt_s)

        dt_s = time.perf_counter() - start_episode_t
        # log_control_info(robot, dt_s, fps=cfg.fps)


import argparse

def get_arguments():
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.fps = 30
    args.resume = False
    args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right"
    args.root = "./data5"
    args.episode = 0  # episode to replay
    args.num_image_writer_processes = 0
    args.num_image_writer_threads_per_camera = 4
    args.video = True
    args.num_episodes = 100
    args.episode_time_s = 30000
    args.play_sounds = False
    args.display_cameras = True
    args.single_task = "move the bottle from the right to the scale right"
    args.use_depth_image = False
    args.use_base = False
    args.push_to_hub = False
    args.policy = None
    # args.teleoperate = True
    args.control_type = "record"
    # args.control_type = "replay"
    return args


# @parser.wrap()
def control_robot(cfg):
    from rosrobot_factory import RobotFactory
    # Create the robot instance via the factory pattern
    robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/collect_data/agilex.yaml", args=cfg)

    if cfg.control_type == "record":
        record(robot, cfg)
    elif cfg.control_type == "replay":
        replay(robot, cfg)


if __name__ == "__main__":
    cfg = get_arguments()
    control_robot(cfg)
@@ -1,3 +0,0 @@
# Force the system libtiff ahead of the conda one (assumption: this works
# around a libtiff symbol clash between the conda env and ROS/OpenCV).
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtiff.so.5
# export LD_LIBRARY_PATH=/home/ubuntu/miniconda3/envs/lerobot/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
@@ -1,769 +0,0 @@
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -*- coding: UTF-8 -*-
"""
#!/usr/bin/python3
"""

import torch
import numpy as np
import os
import pickle
import argparse
from einops import rearrange
import collections
from collections import deque

import rospy
from std_msgs.msg import Header
from geometry_msgs.msg import Twist
from sensor_msgs.msg import JointState, Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
import time
import threading
import math

# Assumed local modules from the original ACT codebase (not shown in this
# commit): set_seed and the policy classes are referenced below but never
# imported in this file as committed.
# from utils import set_seed
# from policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy

import sys
sys.path.append("./")

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']}

inference_thread = None
inference_lock = threading.Lock()
inference_actions = None
inference_timestep = None
def actions_interpolation(args, pre_action, actions, stats):
    steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']
    result = [pre_action]
    post_action = post_process(actions[0])
    # print("pre_action:", pre_action[7:])
    # print("actions_interpolation1:", post_action[:, 7:])
    max_diff_index = 0
    max_diff = -1
    for i in range(post_action.shape[0]):
        diff = 0
        for j in range(pre_action.shape[0]):
            if j == 6 or j == 13:
                continue
            diff += math.fabs(pre_action[j] - post_action[i][j])
        if diff > max_diff:
            max_diff = diff
            max_diff_index = i

    for i in range(max_diff_index, post_action.shape[0]):
        step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j]) / steps[j]) for j in range(pre_action.shape[0])])
        inter = np.linspace(result[-1], post_action[i], step + 2)
        result.extend(inter[1:])
    while len(result) < args.chunk_size + 1:
        result.append(result[-1])
    result = np.array(result)[1:args.chunk_size + 1]
    # print("actions_interpolation2:", result.shape, result[:, 7:])
    result = pre_process(result)
    result = result[np.newaxis, :]
    return result
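
# Illustration (added; not part of the original file). A minimal sketch of the
# densify-then-pad idea used by actions_interpolation: waypoints are expanded
# with np.linspace so that no joint moves more than a fixed step per tick, then
# the sequence is padded/trimmed to a fixed chunk size. All numbers below are
# made up for illustration.
def _interpolation_sketch():
    chunk_size = 5
    start = np.zeros(2)
    target = np.array([0.4, -0.2])
    max_step = 0.1
    # ticks needed so the largest per-joint move stays <= max_step
    n = int(np.ceil(np.abs(target - start).max() / max_step))
    traj = list(np.linspace(start, target, n + 1)[1:])  # densified waypoints
    while len(traj) < chunk_size:                        # pad with last pose
        traj.append(traj[-1])
    return np.array(traj)[:chunk_size]                   # trim to chunk size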


def get_model_config(args):
    # Fixing the random seed makes repeated runs reproducible given the same initial conditions.
    set_seed(1)

    # ACT policy
    # fixed parameters
    if args.policy_class == 'ACT':
        policy_config = {'lr': args.lr,
                         'lr_backbone': args.lr_backbone,
                         'backbone': args.backbone,
                         'masks': args.masks,
                         'weight_decay': args.weight_decay,
                         'dilation': args.dilation,
                         'position_embedding': args.position_embedding,
                         'loss_function': args.loss_function,
                         'chunk_size': args.chunk_size,  # number of queried action steps
                         'camera_names': task_config['camera_names'],
                         'use_depth_image': args.use_depth_image,
                         'use_robot_base': args.use_robot_base,
                         'kl_weight': args.kl_weight,  # KL divergence weight
                         'hidden_dim': args.hidden_dim,  # hidden layer dimension
                         'dim_feedforward': args.dim_feedforward,
                         'enc_layers': args.enc_layers,
                         'dec_layers': args.dec_layers,
                         'nheads': args.nheads,
                         'dropout': args.dropout,
                         'pre_norm': args.pre_norm
                         }
    elif args.policy_class == 'CNNMLP':
        policy_config = {'lr': args.lr,
                         'lr_backbone': args.lr_backbone,
                         'backbone': args.backbone,
                         'masks': args.masks,
                         'weight_decay': args.weight_decay,
                         'dilation': args.dilation,
                         'position_embedding': args.position_embedding,
                         'loss_function': args.loss_function,
                         'chunk_size': 1,  # CNNMLP predicts a single step
                         'camera_names': task_config['camera_names'],
                         'use_depth_image': args.use_depth_image,
                         'use_robot_base': args.use_robot_base
                         }

    elif args.policy_class == 'Diffusion':
        policy_config = {'lr': args.lr,
                         'lr_backbone': args.lr_backbone,
                         'backbone': args.backbone,
                         'masks': args.masks,
                         'weight_decay': args.weight_decay,
                         'dilation': args.dilation,
                         'position_embedding': args.position_embedding,
                         'loss_function': args.loss_function,
                         'chunk_size': args.chunk_size,  # number of queried action steps
                         'camera_names': task_config['camera_names'],
                         'use_depth_image': args.use_depth_image,
                         'use_robot_base': args.use_robot_base,
                         'observation_horizon': args.observation_horizon,
                         'action_horizon': args.action_horizon,
                         'num_inference_timesteps': args.num_inference_timesteps,
                         'ema_power': args.ema_power
                         }
    else:
        raise NotImplementedError

    config = {
        'ckpt_dir': args.ckpt_dir,
        'ckpt_name': args.ckpt_name,
        'ckpt_stats_name': args.ckpt_stats_name,
        'episode_len': args.max_publish_step,
        'state_dim': args.state_dim,
        'policy_class': args.policy_class,
        'policy_config': policy_config,
        'temporal_agg': args.temporal_agg,
        'camera_names': task_config['camera_names'],
    }
    return config

def make_policy(policy_class, policy_config):
    if policy_class == 'ACT':
        policy = ACTPolicy(policy_config)
    elif policy_class == 'CNNMLP':
        policy = CNNMLPPolicy(policy_config)
    elif policy_class == 'Diffusion':
        policy = DiffusionPolicy(policy_config)
    else:
        raise NotImplementedError
    return policy

def get_image(observation, camera_names):
    curr_images = []
    for cam_name in camera_names:
        curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w')

        curr_images.append(curr_image)
    curr_image = np.stack(curr_images, axis=0)
    curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
    return curr_image


def get_depth_image(observation, camera_names):
    curr_images = []
    for cam_name in camera_names:
        curr_images.append(observation['images_depth'][cam_name])
    curr_image = np.stack(curr_images, axis=0)
    curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
    return curr_image

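
# Illustration (added; not part of the original file). What get_image does to
# tensor shapes, on a dummy frame; the camera name and 480x640 resolution are
# assumptions for the example, and no GPU is needed here.
def _image_layout_sketch():
    fake = {'images': {'cam_high': np.zeros((480, 640, 3), dtype=np.uint8)}}
    chw = rearrange(fake['images']['cam_high'], 'h w c -> c h w')  # (3, 480, 640)
    batch = np.stack([chw], axis=0)                                # (1, 3, 480, 640), one row per camera
    tensor = torch.from_numpy(batch / 255.0).float().unsqueeze(0)  # (1, 1, 3, 480, 640)
    return tensor.shape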
def inference_process(args, config, ros_operator, policy, stats, t, pre_action):
    global inference_lock
    global inference_actions
    global inference_timestep
    print_flag = True
    pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"]
    rate = rospy.Rate(args.publish_rate)
    while not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if not result:
            if print_flag:
                print("sync fail")
                print_flag = False
            rate.sleep()
            continue
        print_flag = True
        (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
         puppet_arm_left, puppet_arm_right, robot_base) = result
        obs = collections.OrderedDict()
        image_dict = dict()

        image_dict[config['camera_names'][0]] = img_front
        image_dict[config['camera_names'][1]] = img_left
        image_dict[config['camera_names'][2]] = img_right

        obs['images'] = image_dict

        if args.use_depth_image:
            image_depth_dict = dict()
            image_depth_dict[config['camera_names'][0]] = img_front_depth
            image_depth_dict[config['camera_names'][1]] = img_left_depth
            image_depth_dict[config['camera_names'][2]] = img_right_depth
            obs['images_depth'] = image_depth_dict

        obs['qpos'] = np.concatenate(
            (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
        obs['qvel'] = np.concatenate(
            (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
        obs['effort'] = np.concatenate(
            (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
        if args.use_robot_base:
            obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
            obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0)
        else:
            obs['base_vel'] = [0.0, 0.0]
        # qpos_numpy = np.array(obs['qpos'])

        # Normalize qpos and move it to CUDA
        qpos = pre_pos_process(obs['qpos'])
        qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
        # Fetch the current camera images
        curr_image = get_image(obs, config['camera_names'])
        curr_depth_image = None
        if args.use_depth_image:
            curr_depth_image = get_depth_image(obs, config['camera_names'])
        start_time = time.time()
        all_actions = policy(curr_image, curr_depth_image, qpos)
        end_time = time.time()
        print("model cost time: ", end_time - start_time)
        inference_lock.acquire()
        inference_actions = all_actions.cpu().detach().numpy()
        if pre_action is None:
            pre_action = obs['qpos']
        # print("obs['qpos']:", obs['qpos'][7:])
        if args.use_actions_interpolation:
            inference_actions = actions_interpolation(args, pre_action, inference_actions, stats)
        inference_timestep = t
        inference_lock.release()
        break

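
# Illustration (added; not part of the original file). The worker above hands
# its result to the main loop through module-level globals guarded by
# inference_lock. The same pattern in miniature (all names here are
# hypothetical):
def _handoff_sketch():
    lock = threading.Lock()
    box = {'result': None}

    def worker():
        value = 42  # stand-in for a policy forward pass
        with lock:
            box['result'] = value

    th = threading.Thread(target=worker)
    th.start()
    th.join()
    with lock:
        return box['result']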
def model_inference(args, config, ros_operator, save_episode=True):
    global inference_lock
    global inference_actions
    global inference_timestep
    global inference_thread
    set_seed(1000)

    # 1. Build the policy model (an nn.Module subclass)
    policy = make_policy(config['policy_class'], config['policy_config'])
    # print("model structure\n", policy.model)

    # 2. Load the model weights
    ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name'])
    state_dict = torch.load(ckpt_path)
    new_state_dict = {}
    for key, value in state_dict.items():
        if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]:
            continue
        if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]:
            continue
        new_state_dict[key] = value
    loading_status = policy.deserialize(new_state_dict)
    if not loading_status:
        print("ckpt path not exist")
        return False

    # 3. Move the model to CUDA and switch to eval mode
    policy.cuda()
    policy.eval()

    # 4. Load the dataset statistics
    stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name'])
    # Loads action_mean, action_std, qpos_mean, qpos_std (14-dim each)
    with open(stats_path, 'rb') as f:
        stats = pickle.load(f)

    # Pre- and post-processing functions
    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
    post_process = lambda a: a * stats['action_std'] + stats['action_mean']

    max_publish_step = config['episode_len']
    chunk_size = config['policy_config']['chunk_size']

    # Publish the home poses
    left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
    right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
    left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
    right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]

    ros_operator.puppet_arm_publish_continuous(left0, right0)
    input("Enter any key to continue :")
    ros_operator.puppet_arm_publish_continuous(left1, right1)
    action = None
    # Inference loop
    with torch.inference_mode():
        while not rospy.is_shutdown():
            # step counter within the episode
            t = 0
            max_t = 0
            rate = rospy.Rate(args.publish_rate)
            if config['temporal_agg']:
                all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']])
            while t < max_publish_step and not rospy.is_shutdown():
                # start_time = time.time()
                # query policy
                if config['policy_class'] == "ACT":
                    if t >= max_t:
                        pre_action = action
                        inference_thread = threading.Thread(target=inference_process,
                                                            args=(args, config, ros_operator,
                                                                  policy, stats, t, pre_action))
                        inference_thread.start()
                        inference_thread.join()
                        inference_lock.acquire()
                        if inference_actions is not None:
                            inference_thread = None
                            all_actions = inference_actions
                            inference_actions = None
                            max_t = t + args.pos_lookahead_step
                            if config['temporal_agg']:
                                all_time_actions[[t], t:t + chunk_size] = all_actions
                        inference_lock.release()
                    if config['temporal_agg']:
                        actions_for_curr_step = all_time_actions[:, t]
                        actions_populated = np.all(actions_for_curr_step != 0, axis=1)
                        actions_for_curr_step = actions_for_curr_step[actions_populated]
                        k = 0.01
                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
                        exp_weights = exp_weights / exp_weights.sum()
                        exp_weights = exp_weights[:, np.newaxis]
                        raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True)
                    else:
                        if args.pos_lookahead_step != 0:
                            raw_action = all_actions[:, t % args.pos_lookahead_step]
                        else:
                            raw_action = all_actions[:, t % chunk_size]
                else:
                    raise NotImplementedError
                action = post_process(raw_action[0])
                left_action = action[:7]  # first 7 dims drive the left arm
                right_action = action[7:14]
                ros_operator.puppet_arm_publish(left_action, right_action)  # puppet_arm_publish_continuous_thread
                if args.use_robot_base:
                    vel_action = action[14:16]
                    ros_operator.robot_base_publish(vel_action)
                t += 1
                # end_time = time.time()
                # print("publish: ", t)
                # print("time:", end_time - start_time)
                # print("left_action:", left_action)
                # print("right_action:", right_action)
                rate.sleep()

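
# Illustration (added; not part of the original file). The temporal
# aggregation above blends every still-valid chunk prediction for the current
# step with exponentially decaying weights exp(-k*i); index 0 is the oldest
# populated prediction, so it receives the largest weight. A standalone sketch
# with made-up numbers:
def _temporal_agg_sketch():
    # three overlapping chunk predictions for the same timestep, 2-dim action
    preds = np.array([[0.10, 0.00],
                      [0.12, 0.02],
                      [0.11, 0.01]])
    k = 0.01
    w = np.exp(-k * np.arange(len(preds)))
    w = w / w.sum()
    return (preds * w[:, np.newaxis]).sum(axis=0)  # blended action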
class RosOperator:
    def __init__(self, args):
        self.robot_base_deque = None
        self.puppet_arm_right_deque = None
        self.puppet_arm_left_deque = None
        self.img_front_deque = None
        self.img_right_deque = None
        self.img_left_deque = None
        self.img_front_depth_deque = None
        self.img_right_depth_deque = None
        self.img_left_depth_deque = None
        self.bridge = None
        self.puppet_arm_left_publisher = None
        self.puppet_arm_right_publisher = None
        self.robot_base_publisher = None
        self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_lock = None
        self.args = args
        self.ctrl_state = False
        self.ctrl_state_lock = threading.Lock()
        self.init()
        self.init_ros()

    def init(self):
        self.bridge = CvBridge()
        self.img_left_deque = deque()
        self.img_right_deque = deque()
        self.img_front_deque = deque()
        self.img_left_depth_deque = deque()
        self.img_right_depth_deque = deque()
        self.img_front_depth_deque = deque()
        self.puppet_arm_left_deque = deque()
        self.puppet_arm_right_deque = deque()
        self.robot_base_deque = deque()
        self.puppet_arm_publish_lock = threading.Lock()
        self.puppet_arm_publish_lock.acquire()

    def puppet_arm_publish(self, left, right):
        joint_state_msg = JointState()
        joint_state_msg.header = Header()
        joint_state_msg.header.stamp = rospy.Time.now()  # set timestamp
        joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set joint names
        joint_state_msg.position = left
        self.puppet_arm_left_publisher.publish(joint_state_msg)
        joint_state_msg.position = right
        self.puppet_arm_right_publisher.publish(joint_state_msg)

    def robot_base_publish(self, vel):
        vel_msg = Twist()
        vel_msg.linear.x = vel[0]
        vel_msg.linear.y = 0
        vel_msg.linear.z = 0
        vel_msg.angular.x = 0
        vel_msg.angular.y = 0
        vel_msg.angular.z = vel[1]
        self.robot_base_publisher.publish(vel_msg)

    def puppet_arm_publish_continuous(self, left, right):
        rate = rospy.Rate(self.args.publish_rate)
        left_arm = None
        right_arm = None
        while not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break
        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
        flag = True
        step = 0
        while flag and not rospy.is_shutdown():
            if self.puppet_arm_publish_lock.acquire(False):
                return
            left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
            right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
            flag = False
            for i in range(len(left)):
                if left_diff[i] < self.args.arm_steps_length[i]:
                    left_arm[i] = left[i]
                else:
                    left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            for i in range(len(right)):
                if right_diff[i] < self.args.arm_steps_length[i]:
                    right_arm[i] = right[i]
                else:
                    right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
                    flag = True
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # set timestamp
            joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set joint names
            joint_state_msg.position = left_arm
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = right_arm
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            step += 1
            print("puppet_arm_publish_continuous:", step)
            rate.sleep()

    def puppet_arm_publish_linear(self, left, right):
        num_step = 100
        rate = rospy.Rate(200)

        left_arm = None
        right_arm = None

        while not rospy.is_shutdown():
            if len(self.puppet_arm_left_deque) != 0:
                left_arm = list(self.puppet_arm_left_deque[-1].position)
            if len(self.puppet_arm_right_deque) != 0:
                right_arm = list(self.puppet_arm_right_deque[-1].position)
            if left_arm is None or right_arm is None:
                rate.sleep()
                continue
            else:
                break

        traj_left_list = np.linspace(left_arm, left, num_step)
        traj_right_list = np.linspace(right_arm, right, num_step)

        for i in range(len(traj_left_list)):
            traj_left = traj_left_list[i]
            traj_right = traj_right_list[i]
            traj_left[-1] = left[-1]
            traj_right[-1] = right[-1]
            joint_state_msg = JointState()
            joint_state_msg.header = Header()
            joint_state_msg.header.stamp = rospy.Time.now()  # set timestamp
            joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set joint names
            joint_state_msg.position = traj_left
            self.puppet_arm_left_publisher.publish(joint_state_msg)
            joint_state_msg.position = traj_right
            self.puppet_arm_right_publisher.publish(joint_state_msg)
            rate.sleep()

    def puppet_arm_publish_continuous_thread(self, left, right):
        if self.puppet_arm_publish_thread is not None:
            self.puppet_arm_publish_lock.release()
            self.puppet_arm_publish_thread.join()
            self.puppet_arm_publish_lock.acquire(False)
            self.puppet_arm_publish_thread = None
        self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
        self.puppet_arm_publish_thread.start()

    def get_frame(self):
        if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
                (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
            return False
        if self.args.use_depth_image:
            frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
                              self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
        else:
            frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])

        if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
            return False
        if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
            return False
        if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
            return False

        while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
            self.img_left_deque.popleft()
        img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')

        while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
            self.img_right_deque.popleft()
        img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')

        while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
            self.img_front_deque.popleft()
        img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')

        while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_left_deque.popleft()
        puppet_arm_left = self.puppet_arm_left_deque.popleft()

        while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
            self.puppet_arm_right_deque.popleft()
        puppet_arm_right = self.puppet_arm_right_deque.popleft()

        img_left_depth = None
        if self.args.use_depth_image:
            while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_left_depth_deque.popleft()
            img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')

        img_right_depth = None
        if self.args.use_depth_image:
            while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_right_depth_deque.popleft()
            img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')

        img_front_depth = None
        if self.args.use_depth_image:
            while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
                self.img_front_depth_deque.popleft()
            img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')

        robot_base = None
        if self.args.use_robot_base:
            while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
                self.robot_base_deque.popleft()
            robot_base = self.robot_base_deque.popleft()

        return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
                puppet_arm_left, puppet_arm_right, robot_base)

    def img_left_callback(self, msg):
        if len(self.img_left_deque) >= 2000:
            self.img_left_deque.popleft()
        self.img_left_deque.append(msg)

    def img_right_callback(self, msg):
        if len(self.img_right_deque) >= 2000:
            self.img_right_deque.popleft()
        self.img_right_deque.append(msg)

    def img_front_callback(self, msg):
        if len(self.img_front_deque) >= 2000:
            self.img_front_deque.popleft()
        self.img_front_deque.append(msg)

    def img_left_depth_callback(self, msg):
        if len(self.img_left_depth_deque) >= 2000:
            self.img_left_depth_deque.popleft()
        self.img_left_depth_deque.append(msg)

    def img_right_depth_callback(self, msg):
        if len(self.img_right_depth_deque) >= 2000:
            self.img_right_depth_deque.popleft()
        self.img_right_depth_deque.append(msg)

    def img_front_depth_callback(self, msg):
        if len(self.img_front_depth_deque) >= 2000:
            self.img_front_depth_deque.popleft()
        self.img_front_depth_deque.append(msg)

    def puppet_arm_left_callback(self, msg):
        if len(self.puppet_arm_left_deque) >= 2000:
            self.puppet_arm_left_deque.popleft()
        self.puppet_arm_left_deque.append(msg)

    def puppet_arm_right_callback(self, msg):
        if len(self.puppet_arm_right_deque) >= 2000:
            self.puppet_arm_right_deque.popleft()
        self.puppet_arm_right_deque.append(msg)

    def robot_base_callback(self, msg):
        if len(self.robot_base_deque) >= 2000:
            self.robot_base_deque.popleft()
        self.robot_base_deque.append(msg)

    def ctrl_callback(self, msg):
        self.ctrl_state_lock.acquire()
        self.ctrl_state = msg.data
        self.ctrl_state_lock.release()

    def get_ctrl_state(self):
        self.ctrl_state_lock.acquire()
        state = self.ctrl_state
        self.ctrl_state_lock.release()
        return state

    def init_ros(self):
        rospy.init_node('joint_state_publisher', anonymous=True)
        rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
        if self.args.use_depth_image:
            rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
            rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
        rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
        self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
        self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
        self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)

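
# Illustration (added; not part of the original file). get_frame aligns topics
# by picking frame_time = min over the newest stamp of each stream, then
# popping every stream until its head reaches frame_time. The same idea on
# plain (timestamp, payload) tuples, with made-up stamps:
def _sync_sketch():
    cam = deque([(0.1, 'f0'), (0.2, 'f1'), (0.3, 'f2')])
    arm = deque([(0.2, 'q1'), (0.3, 'q2')])
    frame_time = min(cam[-1][0], arm[-1][0])  # newest common time
    for stream in (cam, arm):
        while stream[0][0] < frame_time:
            stream.popleft()  # drop stale messages
    return cam[0][1], arm[0][1]  # aligned pair: ('f2', 'q2')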
def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
    parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False)
    parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False)
    parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False)
    parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False)
    parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False)
    parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False)
    parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False)
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False)
    parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False)
    parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False)
    parser.add_argument('--dilation', action='store_true',
                        help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False)
    parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                        help="Type of positional embedding to use on top of the image features", required=False)
    parser.add_argument('--masks', action='store_true',
                        help="Train segmentation head if the flag is provided")
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False)
    parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False)
    parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False)
    parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False)

    parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False)
    parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False)
    parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False)
    parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False)
    parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False)
    parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False)
    parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False)
    parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False)
    parser.add_argument('--pre_norm', action='store_true', required=False)

    parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
                        default='/camera_f/color/image_raw', required=False)
    parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
                        default='/camera_l/color/image_raw', required=False)
    parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
                        default='/camera_r/color/image_raw', required=False)

    parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
                        default='/camera_f/depth/image_raw', required=False)
    parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
                        default='/camera_l/depth/image_raw', required=False)
    parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
                        default='/camera_r/depth/image_raw', required=False)

    parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
                        default='/master/joint_left', required=False)
    parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
                        default='/master/joint_right', required=False)
    parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
                        default='/puppet/joint_left', required=False)
    parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
                        default='/puppet/joint_right', required=False)

    parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
                        default='/odom_raw', required=False)
    parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_cmd_topic',
                        default='/cmd_vel', required=False)
    parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
                        default=False, required=False)
    parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate',
                        default=40, required=False)
    parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step',
                        default=0, required=False)
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size',
                        default=32, required=False)
    parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length',
                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False)

    parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation',
                        default=False, required=False)
    parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image',
                        default=False, required=False)

    # for Diffusion
    parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False)
    parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False)
    parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False)
    # type was int in the original, which rejects the fractional default; float is intended
    parser.add_argument('--ema_power', action='store', type=float, help='ema_power', default=0.75, required=False)
    args = parser.parse_args()
    return args


def main():
    args = get_arguments()
    ros_operator = RosOperator(args)
    config = get_model_config(args)
    model_inference(args, config, ros_operator, save_episode=True)


if __name__ == '__main__':
    main()
# python act/inference.py --ckpt_dir ~/train0314/
@@ -1,33 +0,0 @@
import pandas as pd

def read_and_print_parquet_row(file_path, row_index=0):
    """
    Read a Parquet file and print the data of the given row.

    Args:
        file_path (str): path to the Parquet file
        row_index (int): index of the row to print (defaults to row 0)
    """
    try:
        # Read the Parquet file
        df = pd.read_parquet(file_path)

        # Check that the row index is valid
        if row_index >= len(df):
            print(f"Error: row index {row_index} out of range (file has {len(df)} rows)")
            return

        # Print the requested row
        print(f"File: {file_path}")
        print(f"Row {row_index}:\n{'-'*30}")
        print(df.iloc[row_index])

    except FileNotFoundError:
        print(f"Error: file {file_path} does not exist")
    except Exception as e:
        print(f"Read failed: {str(e)}")

# Example usage
if __name__ == "__main__":
    file_path = "example.parquet"  # replace with your Parquet file path
    read_and_print_parquet_row("/home/jgl20/LYT/work/data/data/chunk-000/episode_000000.parquet", row_index=0)  # print row 0
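
# Illustration (added; not part of the original file). For a quick look at an
# episode's schema rather than a single row, printing the column names and
# dtypes is often enough (the default path below is hypothetical):
def inspect_parquet_schema(file_path="example.parquet"):
    df = pd.read_parquet(file_path)
    print(df.columns.tolist())  # e.g. action / observation.* columns
    print(df.dtypes)
    print(f"{len(df)} rows")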
collect_data/replay_data.py
Normal file → Executable file
@@ -10,24 +10,63 @@ from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from geometry_msgs.msg import Twist
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

import sys
sys.path.append("./")
def load_hdf5(dataset_dir, dataset_name):
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        exit()

    with h5py.File(dataset_path, 'r') as root:
        is_sim = root.attrs['sim']
        compressed = root.attrs.get('compress', False)
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        if 'effort' in root.keys():
            effort = root['/observations/effort'][()]
        else:
            effort = None
        action = root['/action'][()]
        base_action = root['/base_action'][()]

        image_dict = dict()
        for cam_name in root[f'/observations/images/'].keys():
            image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]

        if compressed:
            compress_len = root['/compress_len'][()]

    if compressed:
        for cam_id, cam_name in enumerate(image_dict.keys()):
            # un-pad and uncompress
            padded_compressed_image_list = image_dict[cam_name]
            image_list = []
            for frame_id, padded_compressed_image in enumerate(padded_compressed_image_list):  # [:1000] to save memory
                image_len = int(compress_len[cam_id, frame_id])

                compressed_image = padded_compressed_image
                image = cv2.imdecode(compressed_image, 1)
                image_list.append(image)
            image_dict[cam_name] = image_list

    return qpos, qvel, effort, action, base_action, image_dict

def main(args):
    rospy.init_node("replay_node")
    bridge = CvBridge()
    # img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
    # img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
    # img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
    img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
    img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
    img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)

    # puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
    # puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
    puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
    puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)

    master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
    master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)

    # robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
    robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)

    # dataset_dir = args.dataset_dir
@@ -35,78 +74,130 @@ def main(args):
    # task_name = args.task_name
    # dataset_name = f'episode_{episode_idx}'

    dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
    actions = dataset.hf_dataset.select_columns("action")
    velocitys = dataset.hf_dataset.select_columns("observation.velocity")
    efforts = dataset.hf_dataset.select_columns("observation.effort")

    origin_left = [-0.0057, -0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279]
    origin_right = [0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]

    joint_state_msg = JointState()
    joint_state_msg.header = Header()
    joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']  # set joint names
    joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # set joint names
    twist_msg = Twist()

    rate = rospy.Rate(args.fps)

    rate = rospy.Rate(args.frame_rate)

    # qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)

    actions = [[0.1277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.3238, -0.1094,
                0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]]

    last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
    last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
    last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
    rate = rospy.Rate(50)
    for idx in range(len(actions)):
        action = actions[idx]['action'].detach().cpu().numpy()
        velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy()
        effort = efforts[idx]['observation.effort'].detach().cpu().numpy()
        if rospy.is_shutdown():
            break

        new_actions = np.linspace(last_action, action, 5)  # interpolation
        new_velocitys = np.linspace(last_velocity, velocity, 5)  # interpolation
        new_efforts = np.linspace(last_effort, effort, 5)  # interpolation
        last_action = action
        last_velocity = velocity
        last_effort = effort
        for act in new_actions:
            print(np.round(act[:7], 4))
            cur_timestamp = rospy.Time.now()  # set timestamp
            joint_state_msg.header.stamp = cur_timestamp
    if not args.only_pub_master:
        last_action = [-0.0057, -0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279, 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
        rate = rospy.Rate(100)
        for action in actions:
            if rospy.is_shutdown():
                break

            joint_state_msg.position = act[:7]
            joint_state_msg.velocity = last_velocity[:7]
            joint_state_msg.effort = last_effort[:7]
            new_actions = np.linspace(last_action, action, 50)  # interpolation
            last_action = action
            for act in new_actions:
                print(np.round(act[:7], 4))
                cur_timestamp = rospy.Time.now()  # set timestamp
                joint_state_msg.header.stamp = cur_timestamp

                joint_state_msg.position = act[:7]
                master_arm_left_publisher.publish(joint_state_msg)

                joint_state_msg.position = act[7:]
                master_arm_right_publisher.publish(joint_state_msg)

                if rospy.is_shutdown():
                    break
                rate.sleep()

    else:
        i = 0
        while not rospy.is_shutdown() and i < len(actions):
            print("left: ", np.round(qposs[i][:7], 4), " right: ", np.round(qposs[i][7:], 4))

            cam_names = [k for k in image_dicts.keys()]
            image0 = image_dicts[cam_names[0]][i]
            image0 = image0[:, :, [2, 1, 0]]  # swap B and R channel

            image1 = image_dicts[cam_names[1]][i]
            image1 = image1[:, :, [2, 1, 0]]  # swap B and R channel

            image2 = image_dicts[cam_names[2]][i]
            image2 = image2[:, :, [2, 1, 0]]  # swap B and R channel

            cur_timestamp = rospy.Time.now()  # set timestamp

            joint_state_msg.header.stamp = cur_timestamp
            joint_state_msg.position = actions[i][:7]
            master_arm_left_publisher.publish(joint_state_msg)

            joint_state_msg.position = act[7:]
            joint_state_msg.velocity = last_velocity[:7]
            joint_state_msg.effort = last_effort[7:]
            joint_state_msg.position = actions[i][7:]
            master_arm_right_publisher.publish(joint_state_msg)

            if rospy.is_shutdown():
                break
            rate.sleep()
            joint_state_msg.position = qposs[i][:7]
            puppet_arm_left_publisher.publish(joint_state_msg)

            joint_state_msg.position = qposs[i][7:]
            puppet_arm_right_publisher.publish(joint_state_msg)

            img_front_publisher.publish(bridge.cv2_to_imgmsg(image0, "bgr8"))
            img_left_publisher.publish(bridge.cv2_to_imgmsg(image1, "bgr8"))
            img_right_publisher.publish(bridge.cv2_to_imgmsg(image2, "bgr8"))

            twist_msg.linear.x = base_actions[i][0]
            twist_msg.angular.z = base_actions[i][1]
            robot_base_publisher.publish(twist_msg)

            i += 1
            rate.sleep()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
    #                     default='/master/joint_left', required=False)
    # parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
    #                     default='/master/joint_right', required=False)

    args = parser.parse_args()
    args.repo_id = "tangger/test"
    args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
    args.episode = 1  # replay episode
    args.master_arm_left_topic = "/master/joint_left"
    args.master_arm_right_topic = "/master/joint_right"
    args.fps = 30
    parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=False)
    parser.add_argument('--task_name', action='store', type=str, help='Task name.',
                        default="aloha_mobile_dummy", required=False)

    parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', default=0, required=False)

    parser.add_argument('--camera_names', action='store', type=str, help='camera_names',
                        default=['cam_high', 'cam_left_wrist', 'cam_right_wrist'], required=False)

    parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
                        default='/camera_f/color/image_raw', required=False)
    parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
                        default='/camera_l/color/image_raw', required=False)
    parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
                        default='/camera_r/color/image_raw', required=False)

    parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
                        default='/master/joint_left', required=False)
    parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
                        default='/master/joint_right', required=False)

    parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
                        default='/puppet/joint_left', required=False)
    parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
                        default='/puppet/joint_right', required=False)

    parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
                        default='/cmd_vel', required=False)
    parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
                        default=False, required=False)

    parser.add_argument('--frame_rate', action='store', type=int, help='frame_rate',
                        default=30, required=False)

    parser.add_argument('--only_pub_master', action='store_true', help='only_pub_master', required=False)

    args = parser.parse_args()
    main(args)
# python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0
collect_data/requirements.txt
Executable file
@@ -0,0 +1,10 @@
opencv-python==4.9.0.80
matplotlib==3.7.5
h5py==3.8.0
dm-env==1.6
numpy==1.23.4
pyyaml
rospkg==1.5.0
catkin-pkg==1.0.0
empy==3.3.4
PyQt5==5.15.10
@@ -1,70 +0,0 @@
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.robot_devices.utils import busy_wait
import time
import argparse
from agilex_robot import AgilexRobot
import torch

def get_arguments():
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.fps = 30
    args.resume = False
    args.repo_id = "tangger/test"
    args.root = "./data2"
    args.num_image_writer_processes = 0
    args.num_image_writer_threads_per_camera = 4
    args.video = True
    args.num_episodes = 50
    args.episode_time_s = 30000
    args.play_sounds = False
    args.display_cameras = True
    args.single_task = "test test"
    args.use_depth_image = False
    args.use_base = False
    args.push_to_hub = False
    args.policy = None
    args.teleoprate = False
    return args


cfg = get_arguments()
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
inference_time_s = 360
fps = 30
device = "cuda"  # TODO: On Mac, use "mps" or "cpu"

ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"
policy = ACTPolicy.from_pretrained(ckpt_path)
policy.to(device)

for _ in range(inference_time_s * fps):
    start_time = time.perf_counter()

    # Read the follower state and access the frames from the cameras
    observation = robot.capture_observation()
    if observation is None:
        print("Observation is None, skipping...")
        continue

    # Convert to pytorch format: channel first and float32 in [0,1]
    # with batch dimension
    for name in observation:
        if "image" in name:
            observation[name] = observation[name].type(torch.float32) / 255
            observation[name] = observation[name].permute(2, 0, 1).contiguous()
        observation[name] = observation[name].unsqueeze(0)
        observation[name] = observation[name].to(device)

    # Compute the next action with the policy
    # based on the current observation
    action = policy.select_action(observation)
    # Remove batch dimension
    action = action.squeeze(0)
    # Move to cpu, if not already the case
    action = action.to("cpu")
    # Order the robot to move
    robot.send_action(action)

    dt_s = time.perf_counter() - start_time
    busy_wait(1 / fps - dt_s)
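
# Illustration (added; not part of the original file). The preprocessing
# contract above, isolated: an HWC uint8 frame becomes a 1xCxHxW float tensor
# in [0, 1]. Dummy data only; no robot or GPU required.
def _obs_preprocess_sketch():
    frame = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8)
    x = frame.type(torch.float32) / 255  # to float in [0, 1]
    x = x.permute(2, 0, 1).contiguous()  # HWC -> CHW
    x = x.unsqueeze(0)                   # add batch dim -> (1, 3, 480, 640)
    return x.shape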
collect_data/utils.py
Normal file
@@ -0,0 +1,272 @@
import cv2
import numpy as np
import h5py
import time


def display_camera_grid(image_dict, grid_shape=None, window_name="MindRobot-V1 Data Collection", scale=1.0):
    """
    Display multiple camera feeds in a grid (original aspect ratios kept, with optional global scaling).

    Args:
        image_dict: {camera name: image numpy array}
        grid_shape: (rows, cols) layout; None = computed automatically
        window_name: window title
        scale: global display scale (0.5 shows everything at 50% of original size)
    """
    # Input validation and preprocessing
    if not isinstance(image_dict, dict):
        raise TypeError("input must be a dict")

    valid_data = []
    for name, img in image_dict.items():
        if not isinstance(img, np.ndarray):
            continue
        if img.dtype != np.uint8:
            img = img.astype(np.uint8)
        if img.ndim == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
        elif img.shape[2] == 4:
            img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        elif img.shape[2] == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        valid_data.append((name, img))

    if not valid_data:
        print("Error: no valid images to display!")
        return None

    # Compute the grid layout automatically
    num_valid = len(valid_data)
    if grid_shape is None:
        grid_shape = (1, num_valid) if num_valid <= 3 else (2, int(np.ceil(num_valid / 2)))

    rows, cols = grid_shape

    # Maximum size per row / column
    row_heights = [0] * rows
    col_widths = [0] * cols

    for i, (_, img) in enumerate(valid_data[:rows * cols]):
        r, c = i // cols, i % cols
        row_heights[r] = max(row_heights[r], img.shape[0])
        col_widths[c] = max(col_widths[c], img.shape[1])

    # Total canvas size (with the global scale applied)
    canvas_h = int(sum(row_heights) * scale)
    canvas_w = int(sum(col_widths) * scale)

    # Create the canvas
    canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)

    # Display region of each cell
    row_pos = [0] + [int(sum(row_heights[:i + 1]) * scale) for i in range(rows)]
    col_pos = [0] + [int(sum(col_widths[:i + 1]) * scale) for i in range(cols)]

    # Paste the images
    for i, (name, img) in enumerate(valid_data[:rows * cols]):
        r, c = i // cols, i % cols

        # Display region of the current image
        x1, x2 = col_pos[c], col_pos[c + 1]
        y1, y2 = row_pos[r], row_pos[r + 1]

        # Scaled size of the current image
        display_h = int(img.shape[0] * scale)
        display_w = int(img.shape[1] * scale)

        # Resize (aspect ratio preserved)
        resized_img = cv2.resize(img, (display_w, display_h))

        # Place it on the canvas
        canvas[y1:y1 + display_h, x1:x1 + display_w] = resized_img

        # Add the label (font scaled proportionally)
        font_scale = 0.8 * scale
        thickness = max(2, int(2 * scale))
        cv2.putText(canvas, name, (x1 + 10, y1 + 30),
                    cv2.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), thickness)

    # Show the window (auto-fit to screen)
    cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)
    cv2.imshow(window_name, canvas)
    cv2.resizeWindow(window_name, canvas_w, canvas_h)
    cv2.waitKey(1)

    return canvas

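
# Illustration (added; not part of the original file). Minimal usage sketch
# with synthetic frames, so no cameras are needed; note the function expects
# RGB input, since it converts RGB -> BGR internally. Camera names below are
# just examples.
def _grid_demo():
    frames = {
        'cam_high': np.full((480, 640, 3), 80, dtype=np.uint8),
        'cam_left_wrist': np.full((480, 640, 3), 160, dtype=np.uint8),
    }
    return display_camera_grid(frames, scale=0.5)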
# Data-saving function
def save_data(args, timesteps, actions, dataset_path):
    # Data dictionary: the observations hold qpos/qvel/effort from each timestep,
    # while /action holds the command that was actually sent
    data_size = len(actions)
    data_dict = {
        '/observations/qpos': [],
        '/observations/qvel': [],
        '/observations/effort': [],
        '/action': [],
        '/base_action': [],
        # '/base_action_t265': [],
    }

    # Camera dict: one entry per observed image stream
    for cam_name in args.camera_names:
        data_dict[f'/observations/images/{cam_name}'] = []
        if args.use_depth_image:
            data_dict[f'/observations/images_depth/{cam_name}'] = []

    # len(actions): max_timesteps, len(timesteps): max_timesteps + 1
    # Walk through the recorded actions
    while actions:
        # Pop one entry per iteration
        action = actions.pop(0)  # current action
        ts = timesteps.pop(0)    # matching timestep (observation)

        # Append values to the dict:
        # qpos, qvel, effort returned by the timestep
        data_dict['/observations/qpos'].append(ts.observation['qpos'])
        data_dict['/observations/qvel'].append(ts.observation['qvel'])
        data_dict['/observations/effort'].append(ts.observation['effort'])

        # The action that was actually sent
        data_dict['/action'].append(action)
        data_dict['/base_action'].append(ts.observation['base_vel'])

        # Camera data
        # data_dict['/base_action_t265'].append(ts.observation['base_vel_t265'])
        for cam_name in args.camera_names:
            data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])
            if args.use_depth_image:
                data_dict[f'/observations/images_depth/{cam_name}'].append(ts.observation['images_depth'][cam_name])

    t0 = time.time()
    with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2 * 2) as root:
        # File attributes:
        # 1. whether the data comes from simulation
        # 2. whether the images are compressed
        root.attrs['sim'] = False
        root.attrs['compress'] = False

        # Create the 'observations' group (observed states) and its image subgroup
        obs = root.create_group('observations')
        image = obs.create_group('images')
        for cam_name in args.camera_names:
            _ = image.create_dataset(cam_name, (data_size, 480, 640, 3), dtype='uint8',
                                     chunks=(1, 480, 640, 3))
        if args.use_depth_image:
            image_depth = obs.create_group('images_depth')
            for cam_name in args.camera_names:
                _ = image_depth.create_dataset(cam_name, (data_size, 480, 640), dtype='uint16',
                                               chunks=(1, 480, 640))

        _ = obs.create_dataset('qpos', (data_size, 14))
        _ = obs.create_dataset('qvel', (data_size, 14))
        _ = obs.create_dataset('effort', (data_size, 14))
        _ = root.create_dataset('action', (data_size, 14))
        _ = root.create_dataset('base_action', (data_size, 2))

        # Write data_dict into the h5py.File
        for name, array in data_dict.items():
            root[name][...] = array
    print(f'\033[32m\nSaving: {time.time() - t0:.1f} secs. {dataset_path}\033[0m\n')

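
# Readback sketch (an addition for illustration): the episode file written by
# save_data above can be inspected with plain h5py. 'episode_0' is a
# hypothetical path.
def _inspect_episode(dataset_path='episode_0'):
    with h5py.File(dataset_path + '.hdf5', 'r') as root:
        print('sim:', root.attrs['sim'], 'compress:', root.attrs['compress'])
        print('qpos shape:', root['/observations/qpos'].shape)   # (T, 14)
        print('action shape:', root['/action'].shape)            # (T, 14)
        print('base_action shape:', root['/base_action'].shape)  # (T, 2)
        print('cameras:', list(root['/observations/images'].keys()))
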
def is_headless():
    """
    Check if the environment is headless (no display available).

    Returns:
        bool: True if the environment is headless, False otherwise.
    """
    try:
        import tkinter as tk
        root = tk.Tk()
        root.withdraw()
        root.update()
        root.destroy()
        return False
    except Exception:  # any failure to open a window means no display
        return True

def init_keyboard_listener():
    """
    Initialize keyboard listener for control events with new key mappings:
    - Left arrow: Start data recording
    - Right arrow: Save current data
    - Down arrow: Discard current data
    - Up arrow: Replay current data
    - ESC: Early termination

    Returns:
        tuple: (listener, events) - Keyboard listener and events dictionary
    """
    events = {
        "exit_early": False,
        "record_start": False,
        "save_data": False,
        "discard_data": False,
        "replay_data": False
    }

    if is_headless():
        print(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.left:
                print("← Left arrow: STARTING data recording...")
                events.update({
                    "record_start": True,
                    "exit_early": False,
                    "save_data": False,
                    "discard_data": False
                })

            elif key == keyboard.Key.right:
                print("→ Right arrow: SAVING current data...")
                events.update({
                    "save_data": True,
                    "exit_early": False,
                    "record_start": False
                })

            elif key == keyboard.Key.down:
                print("↓ Down arrow: DISCARDING current data...")
                events.update({
                    "discard_data": True,
                    "exit_early": False,
                    "record_start": False
                })

            elif key == keyboard.Key.up:
                print("↑ Up arrow: REPLAYING current data...")
                events.update({
                    "replay_data": True,
                    "exit_early": False
                })

            elif key == keyboard.Key.esc:
                print("ESC: EARLY TERMINATION requested")
                events.update({
                    "exit_early": True,
                    "record_start": False
                })

        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events

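
# Control-loop sketch (an addition, not in the original file): one way the
# events dict returned above could be polled. The polling rate and the
# per-flag handling are illustrative assumptions.
def _demo_event_loop():
    listener, events = init_keyboard_listener()
    while not events["exit_early"]:
        if events["record_start"]:
            pass  # capture one timestep here
        if events["save_data"]:
            # would call save_data(...) here, then clear the flag
            events["save_data"] = False
        if events["discard_data"]:
            events["discard_data"] = False
        time.sleep(0.03)  # ~30 Hz loop rate (an assumption)
    if listener is not None:
        listener.stop()
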
159
collect_data/visualize_episodes.py
Executable file
@@ -0,0 +1,159 @@
#coding=utf-8
import os
import numpy as np
import cv2
import h5py
import argparse
import matplotlib.pyplot as plt

DT = 0.02
# JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
JOINT_NAMES = ["joint0", "joint1", "joint2", "joint3", "joint4", "joint5"]
STATE_NAMES = JOINT_NAMES + ["gripper"]
BASE_STATE_NAMES = ["linear_vel", "angular_vel"]

def load_hdf5(dataset_dir, dataset_name):
    dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
    if not os.path.isfile(dataset_path):
        print(f'Dataset does not exist at \n{dataset_path}\n')
        exit()

    with h5py.File(dataset_path, 'r') as root:
        is_sim = root.attrs['sim']
        compressed = root.attrs.get('compress', False)
        qpos = root['/observations/qpos'][()]
        qvel = root['/observations/qvel'][()]
        if 'effort' in root['/observations'].keys():  # effort lives under /observations
            effort = root['/observations/effort'][()]
        else:
            effort = None
        action = root['/action'][()]
        base_action = root['/base_action'][()]
        image_dict = dict()
        for cam_name in root['/observations/images/'].keys():
            image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]
        if compressed:
            compress_len = root['/compress_len'][()]

    if compressed:
        for cam_id, cam_name in enumerate(image_dict.keys()):
            # un-pad and uncompress
            padded_compressed_image_list = image_dict[cam_name]
            image_list = []
            for frame_id, padded_compressed_image in enumerate(padded_compressed_image_list):  # [:1000] to save memory
                image_len = int(compress_len[cam_id, frame_id])
                compressed_image = padded_compressed_image[:image_len]  # strip the zero padding
                image = cv2.imdecode(compressed_image, 1)
                image_list.append(image)
            image_dict[cam_name] = image_list

    return qpos, qvel, effort, action, base_action, image_dict

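
# Quick-check sketch (an addition): load one episode and print shapes. The
# directory layout here is an assumption matching main() below.
def _quick_check(dataset_dir='./data', task_name='aloha_mobile_dummy', episode_idx=0):
    qpos, qvel, effort, action, base_action, image_dict = load_hdf5(
        os.path.join(dataset_dir, task_name), f'episode_{episode_idx}')
    print('qpos:', qpos.shape, 'action:', action.shape, 'base:', base_action.shape)
    for cam_name, frames in image_dict.items():
        print(cam_name, 'frames:', len(frames))
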
def main(args):
    dataset_dir = args['dataset_dir']
    episode_idx = args['episode_idx']
    task_name = args['task_name']
    dataset_name = f'episode_{episode_idx}'

    qpos, qvel, effort, action, base_action, image_dict = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)

    print('hdf5 loaded!!')

    save_videos(image_dict, action, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
    visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
    visualize_base(base_action, plot_path=os.path.join(dataset_dir, dataset_name + '_base_action.png'))

def save_videos(video, actions, dt, video_path=None):
    cam_names = list(video.keys())
    all_cam_videos = []
    for cam_name in cam_names:
        all_cam_videos.append(video[cam_name])
    all_cam_videos = np.concatenate(all_cam_videos, axis=2)  # concatenate along the width dimension

    n_frames, h, w, _ = all_cam_videos.shape
    fps = int(1 / dt)
    out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
    for t in range(n_frames):
        image = all_cam_videos[t]
        image = image[:, :, [2, 1, 0]]  # swap the B and R channels (RGB -> BGR)
        cv2.imshow("images", image)
        cv2.waitKey(30)
        print("episode_id: ", t, "left: ", np.round(actions[t][:7], 3), "right: ", np.round(actions[t][7:], 3), "\n")
        out.write(image)
    out.release()
    print(f'Saved video to: {video_path}')

def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_overwrite=None):
    if label_overwrite:
        label1, label2 = label_overwrite
    else:
        label1, label2 = 'State', 'Command'

    qpos = np.array(qpos_list)  # ts, dim
    command = np.array(command_list)

    num_ts, num_dim = qpos.shape
    num_figs = num_dim
    fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))

    # plot joint state
    all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
    for dim_idx in range(num_dim):
        ax = axs[dim_idx]
        ax.plot(qpos[:, dim_idx], label=label1, color='orangered')
        ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
        ax.legend()

    # plot arm command
    # for dim_idx in range(num_dim):
    #     ax = axs[dim_idx]
    #     ax.plot(command[:, dim_idx], label=label2)
    #     ax.legend()

    if ylim:
        for dim_idx in range(num_dim):
            ax = axs[dim_idx]
            ax.set_ylim(ylim)

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved qpos plot to: {plot_path}')
    plt.close()

def visualize_base(readings, plot_path=None):
    readings = np.array(readings)  # ts, dim
    num_ts, num_dim = readings.shape
    num_figs = num_dim
    fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))

    # plot raw and moving-average-smoothed base velocity readings
    all_names = BASE_STATE_NAMES
    for dim_idx in range(num_dim):
        ax = axs[dim_idx]
        ax.plot(readings[:, dim_idx], label='raw')
        ax.plot(np.convolve(readings[:, dim_idx], np.ones(20) / 20, mode='same'), label='smoothed_20')
        ax.plot(np.convolve(readings[:, dim_idx], np.ones(10) / 10, mode='same'), label='smoothed_10')
        ax.plot(np.convolve(readings[:, dim_idx], np.ones(5) / 5, mode='same'), label='smoothed_5')
        ax.set_title(f'Base dim {dim_idx}: {all_names[dim_idx]}')
        ax.legend()

    plt.tight_layout()
    plt.savefig(plot_path)
    print(f'Saved base_action plot to: {plot_path}')
    plt.close()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
    parser.add_argument('--task_name', action='store', type=str, help='Task name.',
                        default="aloha_mobile_dummy", required=False)
    parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', default=0, required=False)

    main(vars(parser.parse_args()))
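
# Example invocation (illustrative; the paths are assumptions):
#   python collect_data/visualize_episodes.py --dataset_dir ./data \
#       --task_name aloha_mobile_dummy --episode_idx 0
# With DT = 0.02, the saved video plays back at int(1 / DT) = 50 fps.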