206 lines
7.3 KiB
Python
Executable File
206 lines
7.3 KiB
Python
Executable File
import dataclasses
|
|
import enum
|
|
import logging
|
|
import time
|
|
|
|
import numpy as np
|
|
from openpi_client import websocket_client_policy as _websocket_client_policy
|
|
import tyro
|
|
import rospy
|
|
from std_msgs.msg import Header
|
|
from sensor_msgs.msg import Image, JointState
|
|
from agilex_utils import RosOperator
|
|
|
|
|
|
class EnvMode(enum.Enum):
    """Environment types this client knows how to produce observations for.

    Each member's value is the string identifier of the environment.
    """

    ALOHA = "aloha"
    ALOHA_SIM = "aloha_sim"
    DROID = "droid"
    LIBERO = "libero"
    # Real AgileX ARX rig with three cameras (front, left wrist, right wrist).
    AGILEX_ALOHA = "agilex_arx_3camera_aloha"
|
|
|
|
|
|
@dataclasses.dataclass
class Args:
    """Command-line options for the policy client (parsed with tyro in ``__main__``)."""

    # Address of the websocket policy server.
    host: str = "172.20.103.171"
    port: int = 8090

    # Which environment's observation format to send.
    env: EnvMode = EnvMode.AGILEX_ALOHA
    # Number of steps used to average the inference-time benchmark.
    num_steps: int = 10
|
|
|
|
|
|
def main(args: Args) -> None:
    """Benchmark the remote policy and, for the AgileX rig, stream its actions to ROS.

    Connects to the policy server at ``args.host:args.port`` and logs its metadata.
    Sends one warm-up observation, then times ``args.num_steps`` observe+infer
    round trips.

    For the random-observation environments (ALOHA, ALOHA_SIM, DROID, LIBERO)
    there is no robot to drive, so the function returns after the benchmark.
    For ``EnvMode.AGILEX_ALOHA`` it initialises ROS and then loops until ROS
    shuts down. Each inferred action chunk is published step by step to the
    left/right master-arm ``JointState`` topics.
    """
    policy = _websocket_client_policy.WebsocketClientPolicy(
        host=args.host,
        port=args.port,
    )
    logging.info(f"Server metadata: {policy.get_server_metadata()}")

    if args.env is not EnvMode.AGILEX_ALOHA:
        # These generators take no arguments and need no ROS setup
        # (previously they were called with ROS args and raised TypeError).
        random_obs_fn = {
            EnvMode.ALOHA: _random_observation_aloha,
            EnvMode.ALOHA_SIM: _random_observation_aloha,
            EnvMode.DROID: _random_observation_droid,
            EnvMode.LIBERO: _random_observation_libero,
        }[args.env]
        _benchmark_inference(policy, random_obs_fn, args.num_steps)
        return

    args_ros, ros_operator = init_agilex_3camera_aloha()

    def obs_fn() -> dict:
        return observation_agilex_3camera_aloha(args_ros, ros_operator)

    _benchmark_inference(policy, obs_fn, args.num_steps)
    _stream_actions(policy, obs_fn, args_ros)


def _benchmark_inference(policy, obs_fn, num_steps: int) -> None:
    """Warm the policy up once, then time ``num_steps`` observe+infer round trips."""
    # Send 1 observation to make sure the model is loaded.
    policy.infer(obs_fn())

    if num_steps <= 0:
        logging.warning("num_steps must be positive to benchmark inference; skipping timing.")
        return

    start = time.perf_counter()
    # Loop count and divisor both come from num_steps (the loop used to be hard-coded to 10).
    for _ in range(num_steps):
        policy.infer(obs_fn())
    elapsed = time.perf_counter() - start

    # Note: this average includes observation acquisition, not only the server call.
    avg_ms = 1000 * elapsed / num_steps
    print(f"Total time taken: {elapsed:.2f} s")
    print(f"Average inference time: {avg_ms:.2f} ms")
    if avg_ms < 500:
        logging.info("Inference time is less than 0.5 second! Its good!")
    else:
        logging.warning("Inference time is more than 0.5 second! Its bad!")


def _stream_actions(policy, obs_fn, args_ros) -> None:
    """Publish each inferred action chunk to the master-arm topics until ROS shuts down."""
    master_arm_left_publisher = rospy.Publisher(args_ros.master_arm_left_topic, JointState, queue_size=10)
    master_arm_right_publisher = rospy.Publisher(args_ros.master_arm_right_topic, JointState, queue_size=10)

    joint_state_msg = JointState()
    joint_state_msg.header = Header()
    joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names

    # Use the configured publish rate (previously a duplicated literal 30).
    rate = rospy.Rate(args_ros.publish_rate)

    # Default velocity and effort values (14 entries: first 7 -> left arm, last 7 -> right arm).
    last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
                     0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                     -0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
                     -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                     -0.03296661376953125, -0.03296661376953125]
    last_effort = [-0.021978378295898438, 0.2417583465576172, 0.320878982543945,
                   0.6527481079101562, -0.013187408447265625, -0.013187408447265625,
                   0.0, -0.010990142822265625, -0.010990142822265625,
                   -0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
                   -0.03296661376953125, -0.03296661376953125]

    # Check shutdown in the outer loop as well. Previously only the inner loop
    # broke, so obs_fn() was called again after shutdown and failed.
    while not rospy.is_shutdown():
        actions = policy.infer(obs_fn())['actions']
        for idx, action in enumerate(actions):
            if rospy.is_shutdown():
                break
            print(idx, np.round(action[:7], 4))
            joint_state_msg.header.stamp = rospy.Time.now()  # set timestamp

            # The same message object is reused: first 7 values -> left arm...
            joint_state_msg.position = action[:7]
            joint_state_msg.velocity = last_velocity[:7]
            joint_state_msg.effort = last_effort[:7]
            master_arm_left_publisher.publish(joint_state_msg)

            # ...remaining values -> right arm.
            joint_state_msg.position = action[7:]
            joint_state_msg.velocity = last_velocity[7:]
            joint_state_msg.effort = last_effort[7:]
            master_arm_right_publisher.publish(joint_state_msg)

            if rospy.is_shutdown():
                break
            rate.sleep()
|
|
|
|
|
|
|
|
def init_agilex_3camera_aloha():
    """Build the ROS configuration for the 3-camera AgileX ALOHA rig and its RosOperator.

    Returns:
        ``(args, ros_operator)``. ``args`` is an ``argparse.Namespace`` holding
        the camera/arm topic names, ``publish_rate`` and feature flags set below.
        ``ros_operator`` is the ``RosOperator`` constructed from it.
    """
    import argparse

    # Build the namespace directly instead of ``ArgumentParser().parse_args()``.
    # parse_args() reads sys.argv, which also carries the tyro flags for ``Args``
    # (e.g. --host). An empty parser then exits with "unrecognized arguments".
    args = argparse.Namespace(
        # Camera image topics.
        img_left_topic='/camera_l/color/image_raw',
        img_right_topic='/camera_r/color/image_raw',
        img_front_topic='/camera_f/color/image_raw',
        # Arm joint topics.
        master_arm_left_topic='/master/joint_left',
        master_arm_right_topic='/master/joint_right',
        puppet_arm_left_topic='/puppet/joint_left',
        puppet_arm_right_topic='/puppet/joint_right',
        publish_rate=30,
        use_robot_base=False,
        use_actions_interpolation=False,
        use_depth_image=False,
    )

    ros_operator = RosOperator(args)
    return args, ros_operator
|
|
|
|
def observation_agilex_3camera_aloha(args, ros_operator: RosOperator, prompt: str = "weigh a reagent by a balance"):
    """Wait for a synchronized frame from ``ros_operator`` and build a policy observation.

    Args:
        args: Configuration namespace; only ``args.publish_rate`` is read, as the retry rate.
        ros_operator: Frame source. ``get_frame()`` returns a falsy value until a
            synchronized frame is available.
        prompt: Language instruction sent with the observation.

    Returns:
        A dict with:
          - "state": left then right puppet-arm joint positions, concatenated.
          - "images": front/left/right camera images, each transposed with axes (2, 0, 1).
          - "prompt": the prompt string.

    Raises:
        rospy.ROSInterruptException: if ROS shuts down before a frame arrives.
            Previously this path fell through to an UnboundLocalError.
    """
    rate = rospy.Rate(args.publish_rate)
    print_flag = True  # print "syn fail" only once per waiting period
    while True:
        if rospy.is_shutdown():
            raise rospy.ROSInterruptException("ROS shut down before a synchronized frame was received")
        result = ros_operator.get_frame()
        if result:
            break
        if print_flag:
            print("syn fail")
            print_flag = False
        rate.sleep()

    # Depth images and robot base are returned by get_frame() but unused here.
    (img_front, img_left, img_right,
     _img_front_depth, _img_left_depth, _img_right_depth,
     puppet_arm_left, puppet_arm_right, _robot_base) = result

    state = np.concatenate([puppet_arm_left.position, puppet_arm_right.position])

    # Move the last axis first. Presumably RosOperator yields HWC images and the
    # policy expects CHW (cf. the (3, 224, 224) random ALOHA images); TODO confirm.
    img_front = np.transpose(img_front, (2, 0, 1))
    img_left = np.transpose(img_left, (2, 0, 1))
    img_right = np.transpose(img_right, (2, 0, 1))

    return {
        "state": state,
        "images": {
            "cam_high": img_front,
            "cam_left_wrist": img_left,
            "cam_right_wrist": img_right,
        },
        "prompt": prompt,
    }
|
|
|
|
|
|
def _random_observation_aloha() -> dict:
    """Return a dummy ALOHA observation.

    The state is 14 ones (float). There are four random uint8 images of shape
    (3, 224, 224), drawn in the order cam_high, cam_low, cam_left_wrist,
    cam_right_wrist.
    """
    camera_names = ("cam_high", "cam_low", "cam_left_wrist", "cam_right_wrist")
    images = {
        name: np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)
        for name in camera_names
    }
    return {"state": np.ones((14,)), "images": images, "prompt": "do something"}
|
|
|
|
|
|
def _random_observation_droid() -> dict:
    """Return a dummy DROID observation.

    It has two random uint8 (224, 224, 3) images, a random 7-dim joint
    position and a random 1-dim gripper position.
    """
    image_shape = (224, 224, 3)
    # Draw in a fixed order: exterior image, wrist image, joints, gripper.
    exterior_img = np.random.randint(256, size=image_shape, dtype=np.uint8)
    wrist_img = np.random.randint(256, size=image_shape, dtype=np.uint8)
    joints = np.random.rand(7)
    gripper = np.random.rand(1)
    return {
        "observation/exterior_image_1_left": exterior_img,
        "observation/wrist_image_left": wrist_img,
        "observation/joint_position": joints,
        "observation/gripper_position": gripper,
        "prompt": "do something",
    }
|
|
|
|
|
|
def _random_observation_libero() -> dict:
    """Return a dummy LIBERO observation.

    It has a random 8-dim state and two random uint8 (224, 224, 3) images.
    """
    state = np.random.rand(8)
    image_shape = (224, 224, 3)
    main_img = np.random.randint(256, size=image_shape, dtype=np.uint8)
    wrist_img = np.random.randint(256, size=image_shape, dtype=np.uint8)
    return {
        "observation/state": state,
        "observation/image": main_img,
        "observation/wrist_image": wrist_img,
        "prompt": "do something",
    }
|
|
|
|
|
|
if __name__ == "__main__":
    # Configure logging before anything else emits log records.
    logging.basicConfig(level=logging.INFO)
    # tyro builds the Args dataclass from the command line.
    cli_args = tyro.cli(Args)
    main(cli_args)