使用原生的数据搜集代码

This commit is contained in:
2025-04-13 21:41:45 +08:00
parent 3df284ddd1
commit a4fe5ee09a
20 changed files with 1477 additions and 1544 deletions

View File

@@ -0,0 +1,305 @@
import cv2
import numpy as np
import dm_env
import collections
from collections import deque
import rospy
from sensor_msgs.msg import JointState
from sensor_msgs.msg import Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
from utils import display_camera_grid, save_data
class AlohaRobotRos:
def __init__(self, args):
self.robot_base_deque = None
self.puppet_arm_right_deque = None
self.puppet_arm_left_deque = None
self.master_arm_right_deque = None
self.master_arm_left_deque = None
self.img_front_deque = None
self.img_right_deque = None
self.img_left_deque = None
self.img_front_depth_deque = None
self.img_right_depth_deque = None
self.img_left_depth_deque = None
self.bridge = None
self.args = args
self.init()
self.init_ros()
def init(self):
self.bridge = CvBridge()
self.img_left_deque = deque()
self.img_right_deque = deque()
self.img_front_deque = deque()
self.img_left_depth_deque = deque()
self.img_right_depth_deque = deque()
self.img_front_depth_deque = deque()
self.master_arm_left_deque = deque()
self.master_arm_right_deque = deque()
self.puppet_arm_left_deque = deque()
self.puppet_arm_right_deque = deque()
self.robot_base_deque = deque()
def get_frame(self):
print(len(self.img_left_deque), len(self.img_right_deque), len(self.img_front_deque),
len(self.img_left_depth_deque), len(self.img_right_depth_deque), len(self.img_front_depth_deque))
if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
return False
if self.args.use_depth_image:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
else:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.master_arm_left_deque) == 0 or self.master_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.master_arm_right_deque) == 0 or self.master_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
return False
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_deque.popleft()
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
# print("img_left:", img_left.shape)
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_deque.popleft()
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_deque.popleft()
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
while self.master_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.master_arm_left_deque.popleft()
master_arm_left = self.master_arm_left_deque.popleft()
while self.master_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.master_arm_right_deque.popleft()
master_arm_right = self.master_arm_right_deque.popleft()
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_left_deque.popleft()
puppet_arm_left = self.puppet_arm_left_deque.popleft()
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_right_deque.popleft()
puppet_arm_right = self.puppet_arm_right_deque.popleft()
img_left_depth = None
if self.args.use_depth_image:
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_depth_deque.popleft()
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
top, bottom, left, right = 40, 40, 0, 0
img_left_depth = cv2.copyMakeBorder(img_left_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
img_right_depth = None
if self.args.use_depth_image:
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_depth_deque.popleft()
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
top, bottom, left, right = 40, 40, 0, 0
img_right_depth = cv2.copyMakeBorder(img_right_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
img_front_depth = None
if self.args.use_depth_image:
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_depth_deque.popleft()
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
top, bottom, left, right = 40, 40, 0, 0
img_front_depth = cv2.copyMakeBorder(img_front_depth, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
robot_base = None
if self.args.use_robot_base:
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
self.robot_base_deque.popleft()
robot_base = self.robot_base_deque.popleft()
return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base)
def img_left_callback(self, msg):
if len(self.img_left_deque) >= 2000:
self.img_left_deque.popleft()
self.img_left_deque.append(msg)
def img_right_callback(self, msg):
if len(self.img_right_deque) >= 2000:
self.img_right_deque.popleft()
self.img_right_deque.append(msg)
def img_front_callback(self, msg):
if len(self.img_front_deque) >= 2000:
self.img_front_deque.popleft()
self.img_front_deque.append(msg)
def img_left_depth_callback(self, msg):
# import pdb
# pdb.set_trace()
if len(self.img_left_depth_deque) >= 2000:
self.img_left_depth_deque.popleft()
self.img_left_depth_deque.append(msg)
def img_right_depth_callback(self, msg):
if len(self.img_right_depth_deque) >= 2000:
self.img_right_depth_deque.popleft()
self.img_right_depth_deque.append(msg)
def img_front_depth_callback(self, msg):
if len(self.img_front_depth_deque) >= 2000:
self.img_front_depth_deque.popleft()
self.img_front_depth_deque.append(msg)
def master_arm_left_callback(self, msg):
if len(self.master_arm_left_deque) >= 2000:
self.master_arm_left_deque.popleft()
self.master_arm_left_deque.append(msg)
def master_arm_right_callback(self, msg):
if len(self.master_arm_right_deque) >= 2000:
self.master_arm_right_deque.popleft()
self.master_arm_right_deque.append(msg)
def puppet_arm_left_callback(self, msg):
if len(self.puppet_arm_left_deque) >= 2000:
self.puppet_arm_left_deque.popleft()
self.puppet_arm_left_deque.append(msg)
def puppet_arm_right_callback(self, msg):
if len(self.puppet_arm_right_deque) >= 2000:
self.puppet_arm_right_deque.popleft()
self.puppet_arm_right_deque.append(msg)
def robot_base_callback(self, msg):
if len(self.robot_base_deque) >= 2000:
self.robot_base_deque.popleft()
self.robot_base_deque.append(msg)
def init_ros(self):
rospy.init_node('record_episodes', anonymous=True)
rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
if self.args.use_depth_image:
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.master_arm_left_topic, JointState, self.master_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.master_arm_right_topic, JointState, self.master_arm_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
# rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
def process(self):
timesteps = []
actions = []
# 图像数据
image = np.random.randint(0, 255, size=(480, 640, 3), dtype=np.uint8)
image_dict = dict()
for cam_name in self.args.camera_names:
image_dict[cam_name] = image
count = 0
# input_key = input("please input s:")
# while input_key != 's' and not rospy.is_shutdown():
# input_key = input("please input s:")
rate = rospy.Rate(self.args.frame_rate)
print_flag = True
while (count < self.args.max_timesteps + 1) and not rospy.is_shutdown():
# 2 收集数据
result = self.get_frame()
# import pdb
# pdb.set_trace()
if not result:
if print_flag:
print("syn fail\n")
print_flag = False
rate.sleep()
continue
print_flag = True
count += 1
(img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, master_arm_left, master_arm_right, robot_base) = result
# 2.1 图像信息
image_dict = dict()
image_dict[self.args.camera_names[0]] = img_front
image_dict[self.args.camera_names[1]] = img_left
image_dict[self.args.camera_names[2]] = img_right
# import pdb
# pdb.set_trace()
display_camera_grid(image_dict)
# 2.2 从臂的信息从臂的状态 机械臂示教模式时 会自动订阅
obs = collections.OrderedDict() # 有序的字典
obs['images'] = image_dict
if self.args.use_depth_image:
image_dict_depth = dict()
image_dict_depth[self.args.camera_names[0]] = img_front_depth
image_dict_depth[self.args.camera_names[1]] = img_left_depth
image_dict_depth[self.args.camera_names[2]] = img_right_depth
obs['images_depth'] = image_dict_depth
obs['qpos'] = np.concatenate((np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
obs['qvel'] = np.concatenate((np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
obs['effort'] = np.concatenate((np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
if self.args.use_robot_base:
obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
else:
obs['base_vel'] = [0.0, 0.0]
# 第一帧 只包含first fisrt只保存StepType.FIRST
if count == 1:
ts = dm_env.TimeStep(
step_type=dm_env.StepType.FIRST,
reward=None,
discount=None,
observation=obs)
timesteps.append(ts)
continue
# 时间步
ts = dm_env.TimeStep(
step_type=dm_env.StepType.MID,
reward=None,
discount=None,
observation=obs)
# 主臂保存状态
action = np.concatenate((np.array(master_arm_left.position), np.array(master_arm_right.position)), axis=0)
actions.append(action)
timesteps.append(ts)
print(f"\n{self.args.episode_idx} | Frame data: {count}\r", end="")
if rospy.is_shutdown():
exit(-1)
rate.sleep()
print("len(timesteps): ", len(timesteps))
print("len(actions) : ", len(actions))
return timesteps, actions