457 lines
19 KiB
Python
457 lines
19 KiB
Python
import yaml
|
||
import cv2
|
||
import numpy as np
|
||
import collections
|
||
import dm_env
|
||
import argparse
|
||
from typing import Dict, List, Any, Optional
|
||
from collections import deque
|
||
import rospy
|
||
from cv_bridge import CvBridge
|
||
from std_msgs.msg import Header
|
||
from sensor_msgs.msg import Image, JointState
|
||
from nav_msgs.msg import Odometry
|
||
from rosrobot import Robot
|
||
import torch
|
||
import time
|
||
|
||
|
||
class AgilexRobot(Robot):
    """Agilex dual-arm robot driven through ROS topics.

    Incoming messages are buffered in per-topic deques and popped here in
    time-aligned sets; arm commands are published as ``JointState`` messages.
    This class relies on attributes it does not define itself (presumably set
    up by ``Robot.__init__`` — verify): ``sync_img_queues``,
    ``sync_depth_queues``, ``sync_arm_queues`` (dicts of deques holding stamped
    ROS messages), ``sync_base_queue``, ``use_depth_image``, ``use_robot_base``,
    ``bridge`` (a ``CvBridge``), ``arms``, ``config`` and ``publishers``.
    """

    # Velocity / effort values published with every commanded position, sliced
    # per master arm by joint index. They are fixed numbers (apparently a
    # snapshot of one recorded reading); only the positions come from `action`.
    # NOTE(review): confirm the arm controller ignores or tolerates these fields.
    _COMMAND_VELOCITY = (
        -0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
        0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
        -0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
        -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
        -0.03296661376953125, -0.03296661376953125,
    )
    _COMMAND_EFFORT = (
        -0.021978378295898438, 0.2417583465576172, 4.320878982543945,
        3.6527481079101562, -0.013187408447265625, -0.013187408447265625,
        0.0, -0.010990142822265625, -0.010990142822265625,
        -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
        -0.03296661376953125, -0.03296661376953125,
    )

    @staticmethod
    def _latest_common_time(queues) -> float:
        """Return the smallest of the newest header stamps (seconds) across `queues`.

        Every queue's newest message is at or after this time, so each queue can
        supply a message for it. All queues must be non-empty.
        """
        return min(q[-1].header.stamp.to_sec() for q in queues)

    @staticmethod
    def _pop_synced(queue, min_time: float):
        """Drop messages stamped before `min_time`, then pop and return the next one.

        The loop terminates because the newest message is at or after `min_time`.
        """
        while queue[0].header.stamp.to_sec() < min_time:
            queue.popleft()
        return queue.popleft()

    @staticmethod
    def _pad_depth(depth_img):
        """Zero-pad a depth image with 40 rows above and 40 rows below."""
        return cv2.copyMakeBorder(depth_img, 40, 40, 0, 0, cv2.BORDER_CONSTANT, value=0)

    def get_frame(self) -> Optional[Dict[str, Any]]:
        """Pop one time-synchronized frame from the buffered topic queues.

        Image queues, depth queues (when ``use_depth_image``) and every arm queue
        whose name starts with ``'puppet'`` must hold data. Other arm queues (e.g.
        master arms) and the base queue take part only when they currently hold
        messages, so capture still works while those topics are silent.

        Returns:
            ``{'images': {cam: ndarray}, 'arms': {arm: msg}, 'timestamp': float}``,
            plus ``'depths'`` when depth is enabled and ``'base'`` when a base
            message was available; ``None`` when required data is missing.
        """
        # Required streams must each have at least one message.
        if any(len(q) == 0 for q in self.sync_img_queues.values()):
            print("camera has not get image data")
            return None
        if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()):
            return None
        puppet_queues = [q for name, q in self.sync_arm_queues.items() if name.startswith('puppet')]
        if not puppet_queues or any(len(q) == 0 for q in puppet_queues):
            print("can not get data from puppet topic")
            return None

        # Snapshot which optional streams participate, so the sync timestamp and
        # the extraction below use exactly the same set of (non-empty) queues.
        arm_queues = {name: q for name, q in self.sync_arm_queues.items() if len(q) > 0}
        base_queue = (
            self.sync_base_queue
            if self.use_robot_base and len(self.sync_base_queue) > 0
            else None
        )

        queues = list(self.sync_img_queues.values()) + list(arm_queues.values())
        if self.use_depth_image:
            queues.extend(self.sync_depth_queues.values())
        if base_queue is not None:
            queues.append(base_queue)
        # Each queue's newest message is >= min_time by construction, so no
        # separate "newest < min_time" check is needed (it could never trigger).
        min_time = self._latest_common_time(queues)

        frame_data = {'images': {}, 'arms': {}, 'timestamp': min_time}

        for cam_name, queue in self.sync_img_queues.items():
            msg = self._pop_synced(queue, min_time)
            frame_data['images'][cam_name] = self.bridge.imgmsg_to_cv2(msg, 'passthrough')

        if self.use_depth_image:
            frame_data['depths'] = {}
            for cam_name, queue in self.sync_depth_queues.items():
                msg = self._pop_synced(queue, min_time)
                depth_img = self.bridge.imgmsg_to_cv2(msg, 'passthrough')
                frame_data['depths'][cam_name] = self._pad_depth(depth_img)

        for arm_name, queue in arm_queues.items():
            frame_data['arms'][arm_name] = self._pop_synced(queue, min_time)

        if base_queue is not None:
            frame_data['base'] = self._pop_synced(base_queue, min_time)

        return frame_data

    def teleop_step(self) -> tuple[Optional[dict[str, torch.Tensor]], Optional[dict[str, torch.Tensor]]]:
        """Pop one synchronized frame and convert it into (observation, action) tensors.

        Unlike ``get_frame``, every arm queue (puppet and master) must hold data.
        Arms named ``'puppet*'`` provide the observation; arms named ``'master*'``
        provide the action.

        Returns:
            ``(obs_dict, action_dict)``, or ``(None, None)`` when any required
            queue is empty. ``obs_dict`` keys: ``observation.images.<cam>``,
            optionally ``observation.images.depth_<cam>``, ``observation.state`` /
            ``observation.velocity`` / ``observation.effort`` (puppet arms'
            concatenated joint data, float32) and optionally
            ``observation.base_vel``. ``action_dict`` key: ``action`` (master
            arms' concatenated positions, float32).
        """
        if any(len(q) == 0 for q in self.sync_img_queues.values()):
            return None, None
        if self.use_depth_image and any(len(q) == 0 for q in self.sync_depth_queues.values()):
            return None, None
        if any(len(q) == 0 for q in self.sync_arm_queues.values()):
            return None, None

        base_queue = (
            self.sync_base_queue
            if self.use_robot_base and len(self.sync_base_queue) > 0
            else None
        )

        queues = list(self.sync_img_queues.values()) + list(self.sync_arm_queues.values())
        if self.use_depth_image:
            queues.extend(self.sync_depth_queues.values())
        if base_queue is not None:
            queues.append(base_queue)
        # As in get_frame: every queue's newest stamp is >= min_time by construction.
        min_time = self._latest_common_time(queues)

        obs_dict = {}
        action_dict = {}

        for cam_name, queue in self.sync_img_queues.items():
            img = self.bridge.imgmsg_to_cv2(self._pop_synced(queue, min_time), 'passthrough')
            obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(img)

        if self.use_depth_image:
            for cam_name, queue in self.sync_depth_queues.items():
                depth_img = self._pad_depth(
                    self.bridge.imgmsg_to_cv2(self._pop_synced(queue, min_time), 'passthrough')
                )
                # Adds a trailing channel axis (presumably (H, W) -> (H, W, 1)).
                obs_dict[f"observation.images.depth_{cam_name}"] = torch.from_numpy(depth_img).unsqueeze(-1)

        positions, velocities, efforts, actions = [], [], [], []
        for arm_name, queue in self.sync_arm_queues.items():
            arm_data = self._pop_synced(queue, min_time)
            if arm_name.startswith('puppet'):
                positions.append(np.array(arm_data.position, dtype=np.float32))
                velocities.append(np.array(arm_data.velocity, dtype=np.float32))
                efforts.append(np.array(arm_data.effort, dtype=np.float32))
            if arm_name.startswith('master'):
                actions.append(np.array(arm_data.position, dtype=np.float32))

        # Concatenate per-arm arrays into one flat float32 vector (torch.tensor copies it).
        if positions:
            obs_dict["observation.state"] = torch.tensor(np.concatenate(positions).reshape(-1))
        if velocities:
            obs_dict["observation.velocity"] = torch.tensor(np.concatenate(velocities).reshape(-1))
        if efforts:
            obs_dict["observation.effort"] = torch.tensor(np.concatenate(efforts).reshape(-1))
        if actions:
            action_dict["action"] = torch.tensor(np.concatenate(actions).reshape(-1))

        if base_queue is not None:
            base_data = self._pop_synced(base_queue, min_time)
            # Presumably a nav_msgs/Odometry message: forward speed and yaw rate.
            obs_dict["observation.base_vel"] = torch.tensor(
                [base_data.twist.twist.linear.x, base_data.twist.twist.angular.z],
                dtype=torch.float32,
            )

        return obs_dict, action_dict

    def capture_observation(self):
        """Capture one synchronized observation, without a batch dimension.

        Returns:
            dict with ``observation.state`` / ``observation.velocity`` /
            ``observation.effort`` (puppet arms' joint data concatenated as
            float32), ``observation.images.<cam>``, optionally
            ``observation.images.depth_<cam>`` and ``observation.base_vel``;
            ``None`` when ``get_frame`` cannot build a synchronized frame.
        """
        frame_data = self.get_frame()
        if frame_data is None:
            return None

        obs_dict = {}

        positions, velocities, efforts = [], [], []
        for arm_name, joint_state in frame_data['arms'].items():
            if not arm_name.startswith('puppet'):
                continue
            t0 = time.perf_counter()
            positions.append(torch.from_numpy(np.array(joint_state.position, dtype=np.float32)))
            velocities.append(torch.from_numpy(np.array(joint_state.velocity, dtype=np.float32)))
            efforts.append(torch.from_numpy(np.array(joint_state.effort, dtype=np.float32)))
            # Timing goes to the ROS debug log: printing to stdout on every
            # captured frame floods the console.
            rospy.logdebug(f"read_arm_{arm_name}_pos_dt_s is {time.perf_counter() - t0}")

        if positions:
            obs_dict["observation.state"] = torch.cat(positions)
        if velocities:
            obs_dict["observation.velocity"] = torch.cat(velocities)
        if efforts:
            obs_dict["observation.effort"] = torch.cat(efforts)

        for cam_name, img in frame_data['images'].items():
            t0 = time.perf_counter()
            obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(img)
            rospy.logdebug(f"read_camera_{cam_name}_dt_s is {time.perf_counter() - t0}")

        if self.use_depth_image and 'depths' in frame_data:
            for cam_name, depth_img in frame_data['depths'].items():
                t0 = time.perf_counter()
                # Adds a trailing channel axis (presumably (H, W) -> (H, W, 1)).
                obs_dict[f"observation.images.depth_{cam_name}"] = torch.from_numpy(depth_img).unsqueeze(-1)
                rospy.logdebug(f"read_depth_{cam_name}_dt_s is {time.perf_counter() - t0}")

        if self.use_robot_base and 'base' in frame_data:
            base_data = frame_data['base']
            obs_dict["observation.base_vel"] = torch.tensor(
                [base_data.twist.twist.linear.x, base_data.twist.twist.angular.z],
                dtype=torch.float32,
            )

        return obs_dict

    def send_action(self, action: torch.Tensor) -> torch.Tensor:
        """Publish joint-position commands, one ``JointState`` per master arm.

        ``action`` is split across the arms in ``self.arms`` whose name contains
        ``'master'`` (in iteration order), taking ``len(arm_config['motors'])``
        entries per arm; each slice is published on
        ``self.publishers[f"arm_{arm_name}"]``.

        Args:
            action: concatenated goal positions (tensor, ndarray or sequence).

        Returns:
            The positions actually published (after optional
            ``max_relative_target`` clipping) concatenated into one tensor, or an
            empty tensor when no master arm is configured.
        """
        if isinstance(action, torch.Tensor):
            action = action.detach().cpu().numpy()
        # Also accept plain sequences: each slice must support .tolist() below.
        action = np.asarray(action)

        # A missing key or an explicit None both disable clipping
        # (previously None crashed in `current_pos - None`).
        max_delta = self.config['max_relative_target'] if 'max_relative_target' in self.config else None

        start = 0
        action_sent = []
        for arm_name, arm_config in self.arms.items():
            if "master" not in arm_name:
                continue

            stop = start + len(arm_config.get('motors', []))
            arm_action = action[start:stop]
            arm_velocity = list(self._COMMAND_VELOCITY[start:stop])
            arm_effort = list(self._COMMAND_EFFORT[start:stop])
            start = stop

            if max_delta is not None:
                arm_queue = self.sync_arm_queues[arm_name]
                if len(arm_queue) > 0:
                    # NOTE(review): the reference is the newest message on this
                    # master arm's own queue, not the puppet arm's measured
                    # position — confirm this is the intended safety reference.
                    current_pos = np.array(arm_queue[-1].position)
                    arm_action = np.clip(arm_action, current_pos - max_delta, current_pos + max_delta)

            action_sent.append(arm_action)

            joint_state = JointState()
            joint_state.header = Header()
            joint_state.header.stamp = rospy.Time.now()
            # NOTE(review): fixed 7-entry name list (last one empty) regardless of
            # the arm's motor count — confirm it matches what the controller expects.
            joint_state.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']
            joint_state.position = arm_action.tolist()
            joint_state.velocity = arm_velocity
            joint_state.effort = arm_effort

            self.publishers[f"arm_{arm_name}"].publish(joint_state)

        return torch.from_numpy(np.concatenate(action_sent)) if action_sent else torch.tensor([])
|
||
|
||
# def _init_action_publishers(self) -> None:
|
||
# """Initialize ROS publishers for puppet arms"""
|
||
# self.puppet_arm_publishers = {}
|
||
# # rospy.init_node("replay_node")
|
||
# for arm_name, arm_config in self.arms.items():
|
||
# if not "puppet" in arm_name:
|
||
# # if not "master" in arm_name:
|
||
# continue
|
||
|
||
# if 'topic_name' not in arm_config:
|
||
# rospy.logwarn(f"No puppet topic defined for arm {arm_name}")
|
||
# continue
|
||
|
||
# self.puppet_arm_publishers[arm_name] = rospy.Publisher(
|
||
# arm_config['topic_name'],
|
||
# JointState,
|
||
# queue_size=10
|
||
# )
|
||
|
||
# # Wait for publisher to connect
|
||
# rospy.sleep(0.1)
|
||
|
||
# rospy.loginfo("Initialized puppet arm publishers")
|
||
|
||
|
||
|
||
|
||
# def get_arguments() -> argparse.Namespace:
|
||
# """获取运行时参数"""
|
||
# parser = argparse.ArgumentParser()
|
||
# parser.add_argument('--fps', type=int, help='Frame rate', default=30)
|
||
# parser.add_argument('--max_timesteps', type=int, help='Max timesteps', default=500)
|
||
# parser.add_argument('--episode_idx', type=int, help='Episode index', default=0)
|
||
# parser.add_argument('--use_depth', action='store_true', help='Use depth images')
|
||
# parser.add_argument('--use_base', action='store_true', help='Use robot base')
|
||
# return parser.parse_args()
|
||
|
||
|
||
# if __name__ == "__main__":
|
||
# # 示例用法
|
||
# import json
|
||
# args = get_arguments()
|
||
# robot = AgilexRobot(config_file="/home/jgl20/LYT/work/collect_data_lerobot/1.yaml", args=args)
|
||
# print(json.dumps(robot.features, indent=4))
|
||
# robot.warmup_record()
|
||
# count = 0
|
||
# print_flag = True
|
||
# rate = rospy.Rate(args.fps)
|
||
# while (count < args.max_timesteps + 1) and not rospy.is_shutdown():
|
||
# a, b = robot.teleop_step()
|
||
|
||
# if a is None or b is None:
|
||
# if print_flag:
|
||
# print("syn fail\n")
|
||
# print_flag = False
|
||
# rate.sleep()
|
||
# continue
|
||
# else:
|
||
# print(a)
|
||
|
||
|
||
# timesteps, actions = robot.process()
|
||
# print(timesteps)
|
||
print()  # NOTE(review): indentation is lost in this copy; if this sits at module level it writes an empty line to stdout on every import — looks like a debug leftover, confirm and remove.
|