"""Query an openpi policy server and stream its actions to an AgileX ALOHA rig over ROS.

The script first benchmarks round-trip (observation + inference) latency, then
runs a loop that publishes every predicted action step as ``JointState``
messages on the left/right master-arm topics until ROS shuts down.
"""

import argparse
import dataclasses
import enum
import logging
import time
from typing import Any, Callable, Dict, Tuple

import numpy as np
import rospy
import tyro
from openpi_client import websocket_client_policy as _websocket_client_policy
from sensor_msgs.msg import Image, JointState
from std_msgs.msg import Header

from agilex_utils import RosOperator

# Joint names attached to each per-arm JointState message (7 joints per arm).
_JOINT_NAMES = ("joint0", "joint1", "joint2", "joint3", "joint4", "joint5", "joint6")

# Fixed velocity / effort values sent with every command.
# Entries [:7] go to the left arm, [7:] to the right arm.
# NOTE(review): the right-arm effort entries are identical to the right-arm
# velocity entries -- confirm this is intended and not a copy-paste slip.
_DEFAULT_VELOCITY = (
    -0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
    0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
    -0.010990142822265625,
    -0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
    -0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
    -0.03296661376953125,
)
_DEFAULT_EFFORT = (
    -0.021978378295898438, 0.2417583465576172, 0.320878982543945,
    0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
    0.0,
    -0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
    -0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
    -0.03296661376953125,
)

# Average-latency threshold (milliseconds) below which inference is reported as good.
_LATENCY_BUDGET_MS = 500


class EnvMode(enum.Enum):
    """Supported environments."""

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
    AGILEX_ALOHA = "agilex_arx_3camera_aloha"


@dataclasses.dataclass
class Args:
    """Command-line options, parsed by ``tyro`` in the ``__main__`` guard."""

    # Address of the openpi websocket policy server.
    host: str = "172.20.103.171"
    port: int = 8090
    # Which observation source feeds the policy.
    env: EnvMode = EnvMode.AGILEX_ALOHA
    # Number of timed inference calls in the start-up latency benchmark.
    num_steps: int = 10


def main(args: Args) -> None:
    """Connect to the policy server, benchmark it, then stream actions to the robot.

    Args:
        args: Parsed command-line options.
    """
    # The random-observation factories take no arguments; wrap them so every
    # entry shares the ``(args_ros, ros_operator)`` call signature used below.
    obs_fn = {
        EnvMode.ALOHA: lambda *_: _random_observation_aloha(),
        EnvMode.ALOHA_SIM: lambda *_: _random_observation_aloha(),
        EnvMode.DROID: lambda *_: _random_observation_droid(),
        EnvMode.LIBERO: lambda *_: _random_observation_libero(),
        EnvMode.AGILEX_ALOHA: observation_agilex_3camera_aloha,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    logging.info("Server metadata: %s", policy.get_server_metadata())

    args_ros, ros_operator = init_agilex_3camera_aloha()

    def get_observation() -> Dict[str, Any]:
        return obs_fn(args_ros, ros_operator)

    _benchmark_inference(policy, get_observation, args.num_steps)
    _run_control_loop(policy, get_observation, args_ros)


def _benchmark_inference(
    policy: Any, get_observation: Callable[[], Dict[str, Any]], num_steps: int
) -> None:
    """Warm the policy up once, then time ``num_steps`` observation + inference round trips.

    The measured time includes fetching each observation, not only ``policy.infer``.
    """
    # Send 1 observation to make sure the model is loaded.
    policy.infer(get_observation())

    if num_steps < 1:
        logging.warning("num_steps=%d; skipping the latency benchmark.", num_steps)
        return

    start = time.perf_counter()
    for _ in range(num_steps):
        policy.infer(get_observation())
    elapsed = time.perf_counter() - start

    avg_ms = 1000 * elapsed / num_steps
    print(f"Total time taken: {elapsed:.2f} s")
    print(f"Average inference time: {avg_ms:.2f} ms")
    if avg_ms < _LATENCY_BUDGET_MS:
        logging.info("Inference time is less than 0.5 second! Its good!")
    else:
        logging.warning("Inference time is more than 0.5 second! Its bad!")


def _make_joint_state(stamp: Any, position: Any, velocity: Any, effort: Any) -> JointState:
    """Build a fresh ``JointState`` for one arm with the given timestamp and values."""
    return JointState(
        header=Header(stamp=stamp),
        name=list(_JOINT_NAMES),
        position=np.asarray(position, dtype=float).tolist(),
        velocity=list(velocity),
        effort=list(effort),
    )


def _run_control_loop(
    policy: Any, get_observation: Callable[[], Dict[str, Any]], args_ros: argparse.Namespace
) -> None:
    """Repeatedly infer an action chunk and publish each step to the master-arm topics.

    Each action row is split as ``[:7]`` -> left arm, ``[7:]`` -> right arm
    (assumes 14-D rows, 7 joints per arm -- TODO confirm against the policy config).
    Returns once ROS reports shutdown; ``rospy.Rate.sleep`` may instead raise
    ``rospy.ROSInterruptException`` if shutdown happens mid-sleep.
    """
    left_publisher = rospy.Publisher(args_ros.master_arm_left_topic, JointState, queue_size=10)
    right_publisher = rospy.Publisher(args_ros.master_arm_right_topic, JointState, queue_size=10)
    rate = rospy.Rate(args_ros.publish_rate)

    while not rospy.is_shutdown():
        actions = policy.infer(get_observation())["actions"]
        for idx, action in enumerate(actions):
            if rospy.is_shutdown():
                return
            print(idx, np.round(action[:7], 4))

            # Both arms share one timestamp per step.
            stamp = rospy.Time.now()
            left_publisher.publish(
                _make_joint_state(stamp, action[:7], _DEFAULT_VELOCITY[:7], _DEFAULT_EFFORT[:7])
            )
            right_publisher.publish(
                _make_joint_state(stamp, action[7:], _DEFAULT_VELOCITY[7:], _DEFAULT_EFFORT[7:])
            )

            # Avoid sleeping (and raising) once shutdown has been requested.
            if rospy.is_shutdown():
                return
            rate.sleep()


def init_agilex_3camera_aloha() -> Tuple[argparse.Namespace, RosOperator]:
    """Build the topic configuration for the 3-camera AgileX rig and create its ``RosOperator``.

    NOTE(review): ``rospy.Rate`` / ``rospy.Publisher`` used later need an
    initialized node; presumably ``RosOperator`` calls ``rospy.init_node`` -- confirm.

    Returns:
        ``(args, ros_operator)``: the configuration namespace and the operator built from it.
    """
    # Build the namespace directly: ``ArgumentParser().parse_args()`` would
    # re-parse sys.argv and exit on the flags already consumed by tyro.
    args = argparse.Namespace(
        img_left_topic="/camera_l/color/image_raw",
        img_right_topic="/camera_r/color/image_raw",
        img_front_topic="/camera_f/color/image_raw",
        master_arm_left_topic="/master/joint_left",
        master_arm_right_topic="/master/joint_right",
        puppet_arm_left_topic="/puppet/joint_left",
        puppet_arm_right_topic="/puppet/joint_right",
        publish_rate=30,
        use_robot_base=False,
        use_actions_interpolation=False,
        use_depth_image=False,
    )
    ros_operator = RosOperator(args)
    return args, ros_operator


def observation_agilex_3camera_aloha(args, ros_operator: RosOperator) -> Dict[str, Any]:
    """Block until ``ros_operator`` yields a synchronized frame and convert it to a policy observation.

    Polls ``ros_operator.get_frame()`` at ``args.publish_rate`` Hz, printing
    "syn fail" once if synchronization is not immediately available.

    Returns:
        Dict with ``state`` (left + right puppet joint positions concatenated),
        ``images`` (front/left/right camera images with their last axis moved
        first) and a fixed ``prompt``.

    Raises:
        rospy.ROSInterruptException: if ROS shuts down before a frame is received.
    """
    print_flag = True
    rate = rospy.Rate(args.publish_rate)
    result = None
    while not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if result:
            break
        if print_flag:
            print("syn fail")
            print_flag = False
        rate.sleep()

    if not result:
        raise rospy.ROSInterruptException(
            "ROS shut down before a synchronized frame was received"
        )

    (img_front, img_left, img_right,
     _img_front_depth, _img_left_depth, _img_right_depth,
     puppet_arm_left, puppet_arm_right, _robot_base) = result

    state = np.concatenate([puppet_arm_left.position, puppet_arm_right.position])

    # Move the last axis to the front (presumably HWC -> CHW; TODO confirm get_frame layout).
    img_front = np.transpose(img_front, (2, 0, 1))
    img_left = np.transpose(img_left, (2, 0, 1))
    img_right = np.transpose(img_right, (2, 0, 1))

    return {
        "state": state,
        "images": {
            "cam_high": img_front,
            "cam_left_wrist": img_left,
            "cam_right_wrist": img_right,
        },
        "prompt": "weigh a reagent by a balance",
    }


def _random_observation_aloha() -> dict:
    """Return a dummy ALOHA observation: 14-D state of ones and four random 3x224x224 uint8 images."""
    return {
        "state": np.ones((14,)),
        "images": {
            "cam_high": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_low": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_left_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
            "cam_right_wrist": np.random.randint(256, size=(3, 224, 224), dtype=np.uint8),
        },
        "prompt": "do something",
    }


def _random_observation_droid() -> dict:
    """Return a dummy DROID observation with random 224x224x3 images and random joint/gripper values."""
    return {
        "observation/exterior_image_1_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image_left": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/joint_position": np.random.rand(7),
        "observation/gripper_position": np.random.rand(1),
        "prompt": "do something",
    }


def _random_observation_libero() -> dict:
    """Return a dummy LIBERO observation: random 8-D state and two random 224x224x3 images."""
    return {
        "observation/state": np.random.rand(8),
        "observation/image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "observation/wrist_image": np.random.randint(256, size=(224, 224, 3), dtype=np.uint8),
        "prompt": "do something",
    }


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    try:
        main(tyro.cli(Args))
    except rospy.ROSInterruptException:
        # Standard rospy shutdown path (e.g. Ctrl-C during Rate.sleep); exit cleanly.
        logging.info("ROS shutdown requested; exiting.")