Files
lerobot_aloha/collect_data/replay_data.py
2025-04-05 21:46:49 +08:00

112 lines
5.2 KiB
Python

#coding=utf-8
import os
import numpy as np
import cv2
import h5py
import argparse
import rospy
from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from geometry_msgs.msg import Twist
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
# Publish frequency of the replay loop, in Hz.
PUBLISH_RATE_HZ = 50
# Number of JointState points published per recorded frame. With the values
# here one recorded frame takes INTERP_STEPS / PUBLISH_RATE_HZ = 0.1 s to
# replay, regardless of the rate the dataset was recorded at.
INTERP_STEPS = 5
# Joint names sent with every message (one arm's worth).
# NOTE(review): the 7th name is empty; presumably the gripper. Confirm what
# the subscriber expects.
JOINT_NAMES = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', '']
# State that the first interpolation segment starts from. Each list holds
# len(JOINT_NAMES) left-arm values followed by the right-arm values.
INITIAL_ACTION = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
INITIAL_VELOCITY = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
INITIAL_EFFORT = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]


def _interpolate_segment(start, end, steps):
    """Return `steps` evenly spaced points after `start`, ending exactly at `end`.

    `start` itself is excluded: it is the previous frame's target, which was
    already published as the last point of the previous segment. Including it
    again would publish every recorded frame twice in a row.

    Args:
        start: vector the segment starts from (not included in the output).
        end: vector the segment ends at (included as the last row).
        steps: number of points to produce (>= 1).

    Returns:
        numpy array of shape (steps, len(start)).
    """
    return np.linspace(start, end, steps + 1)[1:]


def _publish_arm(publisher, msg, stamp, position, velocity, effort):
    """Fill `msg` with one arm's joint state, stamp it and publish it."""
    msg.header.stamp = stamp
    msg.position = position
    msg.velocity = velocity
    msg.effort = effort
    publisher.publish(msg)


def main(args):
    """Replay one recorded LeRobot episode on the two master-arm topics.

    Loads episode `args.episode` of `LeRobotDataset(args.repo_id,
    root=args.root)` and reads its `action`, `observation.velocity` and
    `observation.effort` columns. Each frame is published as
    `sensor_msgs/JointState`: the first len(JOINT_NAMES) values go to
    `args.master_arm_left_topic` and the remaining values go to
    `args.master_arm_right_topic`. Between consecutive frames, position,
    velocity and effort are all linearly interpolated over INTERP_STEPS points
    published at PUBLISH_RATE_HZ. The function returns when the episode is
    exhausted or ROS shuts down.

    Args:
        args: namespace providing `repo_id`, `root`, `episode`,
            `master_arm_left_topic` and `master_arm_right_topic`.
    """
    rospy.init_node("replay_node")
    left_pub = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
    right_pub = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)
    # Use one message object per arm so the two publishes never share state.
    left_msg = JointState(header=Header(), name=list(JOINT_NAMES))
    right_msg = JointState(header=Header(), name=list(JOINT_NAMES))

    dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
    frames = dataset.hf_dataset.select_columns(
        ["action", "observation.velocity", "observation.effort"]
    )

    n_joints = len(JOINT_NAMES)
    rate = rospy.Rate(PUBLISH_RATE_HZ)
    last_action, last_velocity, last_effort = INITIAL_ACTION, INITIAL_VELOCITY, INITIAL_EFFORT

    for idx in range(len(frames)):
        if rospy.is_shutdown():
            return
        row = frames[idx]
        action = row["action"].detach().cpu().numpy()
        velocity = row["observation.velocity"].detach().cpu().numpy()
        effort = row["observation.effort"].detach().cpu().numpy()

        # Interpolate all three quantities so each published point is consistent.
        segment = zip(
            _interpolate_segment(last_action, action, INTERP_STEPS),
            _interpolate_segment(last_velocity, velocity, INTERP_STEPS),
            _interpolate_segment(last_effort, effort, INTERP_STEPS),
        )
        last_action, last_velocity, last_effort = action, velocity, effort

        for pos, vel, eff in segment:
            print(np.round(pos[:n_joints], 4))
            stamp = rospy.Time.now()  # both arms share the same timestamp
            _publish_arm(left_pub, left_msg, stamp,
                         pos[:n_joints].tolist(), vel[:n_joints].tolist(), eff[:n_joints].tolist())
            _publish_arm(right_pub, right_msg, stamp,
                         pos[n_joints:].tolist(), vel[n_joints:].tolist(), eff[n_joints:].tolist())
            if rospy.is_shutdown():
                return
            rate.sleep()
def parse_args(argv=None):
    """Parse command-line options for the replay script.

    Every default equals the value this script previously hard-coded, so
    running it without arguments behaves as before.

    Args:
        argv: argument list to parse; None means sys.argv[1:].

    Returns:
        argparse.Namespace with `repo_id`, `root`, `episode`,
        `master_arm_left_topic`, `master_arm_right_topic` and `fps`.
    """
    parser = argparse.ArgumentParser(
        description="Replay a recorded LeRobot episode on the master-arm JointState topics.")
    parser.add_argument('--repo_id', type=str, default="tangger/test",
                        help='LeRobot dataset repo id')
    parser.add_argument('--root', type=str, default="/home/ubuntu/LYT/aloha_lerobot/data1",
                        help='local root directory of the dataset')
    parser.add_argument('--episode', type=int, default=1,
                        help='index of the episode to replay')
    parser.add_argument('--master_arm_left_topic', type=str, default="/master/joint_left",
                        help='JointState topic for the left master arm')
    parser.add_argument('--master_arm_right_topic', type=str, default="/master/joint_right",
                        help='JointState topic for the right master arm')
    parser.add_argument('--fps', type=int, default=30,
                        help='frame-rate value passed through to main() as args.fps')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Example: python replay_data.py --episode 0 --root /path/to/data
    main(parse_args())