Files
tangger 3827c0e255
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
new lyt_aloha_real
2025-04-27 20:58:23 +08:00

206 lines
7.3 KiB
Python
Executable File

import dataclasses
import enum
import logging
import time
import numpy as np
from openpi_client import websocket_client_policy as _websocket_client_policy
import tyro
import rospy
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from agilex_utils import RosOperator
class EnvMode(enum.Enum):
    """Supported environments.

    Each member's value is the string identifier of that environment. ``main``
    maps each member to the function that builds its observations.
    """
    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"  # main uses the same random ALOHA observation builder as ALOHA
    DROID = "droid"
    LIBERO = "libero"
    AGILEX_ALOHA = "agilex_arx_3camera_aloha"  # 3-camera AgileX rig; observations are read from ROS via RosOperator
@dataclasses.dataclass
class Args:
    """Command-line options, parsed with ``tyro.cli`` in the ``__main__`` guard."""

    # Hostname/IP of the websocket policy server.
    host: str = "172.20.103.171"
    # Port of the websocket policy server.
    port: int = 8090
    # Environment whose observation builder is used to query the policy.
    env: EnvMode = EnvMode.AGILEX_ALOHA
    # Benchmark size: main divides the total benchmark time by this to report
    # the average latency per inference call.
    num_steps: int = 10
def main(args: Args) -> None:
    """Benchmark the remote policy, then stream its actions to the AgileX master arms.

    1. Select the observation builder for ``args.env``.
    2. Connect to the websocket policy server at ``args.host:args.port``.
    3. Send one warm-up request, then time ``args.num_steps`` requests and report
       the average latency (warning when it is 500 ms or more).
    4. Until ROS shuts down: request an action chunk and publish each action as
       ``JointState`` messages. The first 7 values go to the left master-arm topic
       and the remaining values to the right one, at ``publish_rate`` Hz.

    Args:
        args: Parsed command-line options.
    """
    obs_fn = {
        EnvMode.ALOHA: _random_observation_aloha,
        EnvMode.ALOHA_SIM: _random_observation_aloha,
        EnvMode.DROID: _random_observation_droid,
        EnvMode.LIBERO: _random_observation_libero,
        EnvMode.AGILEX_ALOHA: observation_agilex_3camera_aloha,
    }[args.env]

    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    logging.info(f"Server metadata: {policy.get_server_metadata()}")

    args_ros, ros_operator = init_agilex_3camera_aloha()

    def get_observation() -> dict:
        # Only the AgileX builder reads from ROS; the random builders take no
        # arguments, so they must be called without (args_ros, ros_operator).
        if args.env is EnvMode.AGILEX_ALOHA:
            return obs_fn(args_ros, ros_operator)
        return obs_fn()

    # Send 1 observation to make sure the model is loaded.
    policy.infer(get_observation())

    # Benchmark: time exactly num_steps requests so the average is computed over
    # the same number of calls that were made. perf_counter is a monotonic clock
    # intended for measuring intervals.
    if args.num_steps > 0:
        start = time.perf_counter()
        for _ in range(args.num_steps):
            policy.infer(get_observation())
        elapsed = time.perf_counter() - start
        avg_ms = 1000 * elapsed / args.num_steps
        print(f"Total time taken: {elapsed:.2f} s")
        print(f"Average inference time: {avg_ms:.2f} ms")
        if avg_ms < 500:
            logging.info("Inference time is less than 0.5 second! Its good!")
        else:
            logging.warning("Inference time is more than 0.5 second! Its bad!")

    # Publishers for the master-arm command topics.
    master_arm_left_publisher = rospy.Publisher(args_ros.master_arm_left_topic, JointState, queue_size=10)
    master_arm_right_publisher = rospy.Publisher(args_ros.master_arm_right_topic, JointState, queue_size=10)
    joint_state_msg = JointState()
    joint_state_msg.header = Header()
    # Joint names (7 per arm).
    joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']
    rate = rospy.Rate(args_ros.publish_rate)

    # Default velocity and effort values sent with every command. The first 7
    # entries go to the left arm and the rest to the right arm.
    # NOTE(review): these look like a single recorded reading — confirm intent.
    last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
                     0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                     -0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                     -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                     -0.03296661376953125, -0.03296661376953125]
    last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
                   0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
                   0.0, -0.010990142822265625, -0.010990142822265625,
                   -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                   -0.03296661376953125, -0.03296661376953125]

    # Loop on the shutdown flag itself: a `break` in the inner loop alone would
    # not end the outer loop, and the next inference would run after shutdown.
    while not rospy.is_shutdown():
        actions = policy.infer(get_observation())['actions']
        for idx, action in enumerate(actions):
            if rospy.is_shutdown():
                break
            print(idx, np.round(action[:7], 4))
            # Both arm messages for this action share one timestamp.
            joint_state_msg.header.stamp = rospy.Time.now()

            joint_state_msg.position = action[:7]
            joint_state_msg.velocity = last_velocity[:7]
            joint_state_msg.effort = last_effort[:7]
            master_arm_left_publisher.publish(joint_state_msg)

            # The same message object is reused for the right arm.
            joint_state_msg.position = action[7:]
            joint_state_msg.velocity = last_velocity[7:]
            joint_state_msg.effort = last_effort[7:]
            master_arm_right_publisher.publish(joint_state_msg)

            if rospy.is_shutdown():
                break
            rate.sleep()
def build_agilex_ros_args():
    """Return the fixed ROS configuration for the 3-camera AgileX ALOHA rig.

    The values are built directly into an ``argparse.Namespace``. Calling
    ``parse_args()`` would re-read ``sys.argv``, which ``tyro`` has already
    parsed, and would abort with "unrecognized arguments" whenever the script
    is given any CLI flag.

    Returns:
        argparse.Namespace with camera/arm topic names, ``publish_rate`` (30) and
        feature flags (robot base, action interpolation, depth images all off).
    """
    import argparse

    return argparse.Namespace(
        img_left_topic='/camera_l/color/image_raw',
        img_right_topic='/camera_r/color/image_raw',
        img_front_topic='/camera_f/color/image_raw',
        master_arm_left_topic='/master/joint_left',
        master_arm_right_topic='/master/joint_right',
        puppet_arm_left_topic='/puppet/joint_left',
        puppet_arm_right_topic='/puppet/joint_right',
        publish_rate=30,
        use_robot_base=False,
        use_actions_interpolation=False,
        use_depth_image=False,
    )


def init_agilex_3camera_aloha():
    """Create the ROS configuration and a ``RosOperator`` built from it.

    Returns:
        Tuple ``(args, ros_operator)``: the ``argparse.Namespace`` from
        ``build_agilex_ros_args`` and ``RosOperator(args)``. What ``RosOperator``
        sets up (presumably topic subscriptions) is defined in ``agilex_utils``.
    """
    args = build_agilex_ros_args()
    ros_operator = RosOperator(args)
    return args, ros_operator
def observation_agilex_3camera_aloha(args, ros_operator: RosOperator, prompt: str = "weigh a reagent by a balance"):
    """Wait for a synchronized frame from ROS and package it as a policy observation.

    Polls ``ros_operator.get_frame()`` at ``args.publish_rate`` Hz until it
    returns a truthy result. "syn fail" is printed once per call if the first
    attempt fails.

    Args:
        args: ROS configuration; only ``args.publish_rate`` is read here.
        ros_operator: Frame source whose ``get_frame()`` returns a 9-tuple
            (front/left/right images, three depth images, left/right puppet-arm
            joint states, robot base) or a falsy value when unavailable.
        prompt: Language instruction sent with the observation.

    Returns:
        Dict with "state" (left then right puppet-arm joint positions,
        concatenated), "images" (front/left/right images with axes permuted
        (2, 0, 1)) and "prompt".

    Raises:
        RuntimeError: If ROS shuts down before a synchronized frame is received.
    """
    rate = rospy.Rate(args.publish_rate)
    warned = False
    result = None
    while not rospy.is_shutdown():
        result = ros_operator.get_frame()
        if result:
            break
        if not warned:
            print("syn fail")
            warned = True
        rate.sleep()

    # Without this check, a shutdown before the first good frame would leave the
    # frame variables unbound and fail later with an UnboundLocalError.
    if not result:
        raise RuntimeError("ROS shut down before a synchronized frame was received")

    (img_front, img_left, img_right, _front_depth, _left_depth, _right_depth,
     puppet_arm_left, puppet_arm_right, _robot_base) = result

    state = np.concatenate([puppet_arm_left.position, puppet_arm_right.position])
    # Move the last axis first; presumably HWC -> CHW for the policy — TODO confirm
    # that get_frame() returns HWC images.
    return {
        "state": state,
        "images": {
            "cam_high": np.transpose(img_front, (2, 0, 1)),
            "cam_left_wrist": np.transpose(img_left, (2, 0, 1)),
            "cam_right_wrist": np.transpose(img_right, (2, 0, 1)),
        },
        "prompt": prompt,
    }
def _random_observation_aloha() -> dict:
    """Build a dummy ALOHA observation.

    Returns:
        Dict with a 14-element all-ones "state", four random uint8 camera images
        of shape (3, 224, 224) under "images", and a placeholder "prompt".
    """
    camera_names = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
    images = {
        camera: np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)
        for camera in camera_names
    }
    observation = {"state": np.ones((14,))}
    observation["images"] = images
    observation["prompt"] = "do something"
    return observation
def _random_observation_droid() -> dict:
    """Build a dummy DROID observation.

    Returns:
        Dict with two random uint8 (224, 224, 3) images, a 7-element joint
        position and a 1-element gripper position drawn uniformly from [0, 1),
        and a placeholder prompt.
    """
    image_shape = (224, 224, 3)
    exterior = np.random.randint(256, size=image_shape, dtype=np.uint8)
    wrist = np.random.randint(256, size=image_shape, dtype=np.uint8)
    joints = np.random.rand(7)
    gripper = np.random.rand(1)
    return {
        "observation/exterior_image_1_left": exterior,
        "observation/wrist_image_left": wrist,
        "observation/joint_position": joints,
        "observation/gripper_position": gripper,
        "prompt": "do something",
    }
def _random_observation_libero() -> dict:
    """Build a dummy LIBERO observation.

    Returns:
        Dict with an 8-element state drawn uniformly from [0, 1), two random
        uint8 (224, 224, 3) images, and a placeholder prompt.
    """
    image_shape = (224, 224, 3)
    state = np.random.rand(8)
    agent_view = np.random.randint(256, size=image_shape, dtype=np.uint8)
    wrist_view = np.random.randint(256, size=image_shape, dtype=np.uint8)
    return {
        "observation/state": state,
        "observation/image": agent_view,
        "observation/wrist_image": wrist_view,
        "prompt": "do something",
    }
if __name__ == "__main__":
    # Log at INFO so the server metadata and latency verdict are shown.
    logging.basicConfig(level=logging.INFO)
    cli_args = tyro.cli(Args)
    main(cli_args)