modify code structure

2025-04-07 19:45:34 +08:00
parent 91c2b7b0cb
commit d843a990a3
25 changed files with 2135 additions and 333 deletions

.gitignore (vendored)

@@ -1,2 +1,3 @@
cobot_magic/
librealsense/
data*/

View File

@@ -19,6 +19,12 @@ cameras:
rgb_shape: [480, 640, 3]
width: 480
height: 640
cam_high:
img_topic_name: /camera/color/image_raw
depth_topic_name: /camera/depth/image_rect_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
arm:
master_left:

View File

@@ -408,7 +408,7 @@ def get_arguments():
args.fps = 30
args.resume = False
args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right"
args.root = "/home/ubuntu/LYT/aloha_lerobot/data4"
args.root = "./data5"
args.episode = 0 # replay episode
args.num_image_writer_processes = 0
args.num_image_writer_threads_per_camera = 4
@@ -429,51 +429,26 @@ def get_arguments():
# @parser.wrap()
# def control_robot(cfg: ControlPipelineConfig):
# init_logging()
# logging.info(pformat(asdict(cfg)))
# # robot = make_robot_from_config(cfg.robot)
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# if isinstance(cfg.control, RecordControlConfig):
# print(cfg.control)
# record(robot, cfg.control)
# elif isinstance(cfg.control, ReplayControlConfig):
# replay(robot, cfg.control)
# # if robot.is_connected:
# # # Disconnect manually to avoid a "Core dump" during process
# # # termination due to camera threads not properly exiting.
# # robot.disconnect()
# @parser.wrap()
def control_robot(cfg):
# robot = make_robot_from_config(cfg.robot)
from agilex_robot import AgilexRobot
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
from rosrobot_factory import RobotFactory
# Create the robot instance via the factory pattern
robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/collect_data/agilex.yaml", args=cfg)
if cfg.control_type == "record":
record(robot, cfg)
elif cfg.control_type == "replay":
replay(robot, cfg)
# if robot.is_connected:
# # Disconnect manually to avoid a "Core dump" during process
# # termination due to camera threads not properly exiting.
# robot.disconnect()
if __name__ == "__main__":
cfg = get_arguments()
control_robot(cfg)
# control_robot()
# cfg = get_arguments()
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# Create the robot instance via the factory pattern
# robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# print(robot.features.items())
# print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
# record(robot, cfg)

View File

@@ -1,26 +0,0 @@
import yaml
import argparse
from typing import Dict, List, Any, Optional
from rosrobot import Robot
from agilex_robot import AgilexRobot
class RobotFactory:
@staticmethod
def create(config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
"""
Create the appropriate robot instance based on the configuration file
Args:
config_file: Path to the configuration file
args: Runtime arguments
"""
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
robot_type = config.get('robot_type', 'agilex')
if robot_type == 'agilex':
return AgilexRobot(config_file, args)
# Other robot types can be added here
else:
raise ValueError(f"Unsupported robot type: {robot_type}")

init_robot.bash (new file)

@@ -0,0 +1,2 @@
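# Source the isolated ROS Noetic workspace, bring up the CAN bus, then launch both follower arms (paths assume the cobot_magic checkout)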
source ~/ros_noetic/devel_isolated/setup.bash
cd cobot_magic/remote_control-x86-can-v2 && ./tools/can.sh && ./tools/jgl_2follower.sh

lerobot (submodule)

Submodule lerobot added at 1c873df5c0

lerobot_aloha/README.MD (new file)

@@ -0,0 +1,3 @@
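Record a dataset, then train an ACT policy on it (the repo ids and data paths below are examples):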
python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin." --control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data
python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1

View File

@@ -0,0 +1,461 @@
import logging
import time
from dataclasses import asdict
from pprint import pformat
from pprint import pprint
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,
RecordControlConfig,
RemoteRobotConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.common.robot_devices.control_utils import (
# init_keyboard_listener,
record_episode,
stop_recording,
is_headless
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.common.utils.utils import get_safe_torch_device
from contextlib import nullcontext
from copy import copy
import torch
import rospy
import cv2
from lerobot.configs import parser
from common.agilex_robot import AgilexRobot
from common.rosrobot_factory import RobotFactory
########################################################################################
# Control modes
########################################################################################
def predict_action(observation, policy, device, use_amp):
observation = copy(observation)
with (
torch.inference_mode(),
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
):
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
return action
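For reference, a minimal smoke-test sketch of the input format predict_action expects (hypothetical key names and shapes; in the real loop the observation comes from robot.capture_observation()):

import torch
# Hypothetical observation: HWC uint8 images, 1-D float32 state (not part of this commit).
obs = {
    "observation.images.cam_high": torch.zeros(480, 640, 3, dtype=torch.uint8),
    "observation.state": torch.zeros(14),
}
# action = predict_action(obs, policy, torch.device("cuda"), use_amp=False)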
def control_loop(
robot,
control_time_s=None,
teleoperate=False,
display_cameras=False,
dataset: LeRobotDataset | None = None,
events=None,
policy = None,
fps: int | None = None,
single_task: str | None = None,
):
# TODO(rcadene): Add option to record logs
# if not robot.is_connected:
# robot.connect()
if events is None:
events = {"exit_early": False}
if control_time_s is None:
control_time_s = float("inf")
if dataset is not None and single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
timestamp = 0
start_episode_t = time.perf_counter()
rate = rospy.Rate(fps)
print_flag = True
while timestamp < control_time_s and not rospy.is_shutdown():
# print(timestamp < control_time_s)
# print(rospy.is_shutdown())
start_loop_t = time.perf_counter()
if teleoperate:
observation, action = robot.teleop_step()
if observation is None or action is None:
if print_flag:
print("sync data fail, retrying...\n")
print_flag = False
rate.sleep()
continue
else:
# pass
observation = robot.capture_observation()
if policy is not None:
pred_action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
action = robot.send_action(pred_action)
action = {"action": action}
if dataset is not None:
frame = {**observation, **action, "task": single_task}
dataset.add_frame(frame)
# if display_cameras and not is_headless():
# image_keys = [key for key in observation if "image" in key]
# for key in image_keys:
# if "depth" in key:
# pass
# else:
# cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
# print(1)
# cv2.waitKey(1)
if display_cameras and not is_headless():
image_keys = [key for key in observation if "image" in key]
# Screen resolution (assumed to be 1920x1080; adjust to the actual display)
screen_width = 1920
screen_height = 1080
# Compute the window grid layout
num_images = len(image_keys)
max_columns = int(screen_width / 640) # assume each window is 640 px wide
rows = (num_images + max_columns - 1) // max_columns # rows needed
columns = min(num_images, max_columns) # columns actually used
# Iterate over all image keys and display them
for idx, key in enumerate(image_keys):
if "depth" in key:
continue # skip depth images
# Convert the image from RGB to BGR
image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
# Create the window
cv2.imshow(key, image)
# Compute the window position
window_width = 640
window_height = 480
row = idx // max_columns
col = idx % max_columns
x_position = col * window_width
y_position = row * window_height
# Move the window to its slot
cv2.moveWindow(key, x_position, y_position)
# Wait 1 ms so OpenCV can process GUI events
cv2.waitKey(1)
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
# log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_episode_t
if events["exit_early"]:
events["exit_early"] = False
break
def init_keyboard_listener():
# Allow to exit early while recording an episode or resetting the environment,
# by tapping the right arrow key '->'. This might require a sudo permission
# to allow your terminal to monitor keyboard events.
events = {}
events["exit_early"] = False
events["record_start"] = False
events["rerecord_episode"] = False
events["stop_recording"] = False
if is_headless():
logging.warning(
"Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
)
listener = None
return listener, events
# Only import pynput if not in a headless environment
from pynput import keyboard
def on_press(key):
try:
if key == keyboard.Key.right:
print("Right arrow key pressed. Exiting loop...")
events["exit_early"] = True
events["record_start"] = False
elif key == keyboard.Key.left:
print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
events["rerecord_episode"] = True
events["exit_early"] = True
elif key == keyboard.Key.esc:
print("Escape key pressed. Stopping data recording...")
events["stop_recording"] = True
events["exit_early"] = True
elif key == keyboard.Key.up:
print("Up arrow pressed. Start data recording...")
events["record_start"] = True
except Exception as e:
print(f"Error handling key press: {e}")
listener = keyboard.Listener(on_press=on_press)
listener.start()
return listener, events
def stop_recording(robot, listener, display_cameras):
if not is_headless():
if listener is not None:
listener.stop()
if display_cameras:
cv2.destroyAllWindows()
def record_episode(
robot,
dataset,
events,
episode_time_s,
display_cameras,
policy,
fps,
single_task,
):
control_loop(
robot=robot,
control_time_s=episode_time_s,
display_cameras=display_cameras,
dataset=dataset,
events=events,
policy=policy,
fps=fps,
teleoperate=policy is None,
single_task=single_task,
)
def record(
robot,
cfg
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
if cfg.resume:
dataset = LeRobotDataset(
cfg.repo_id,
root=cfg.root,
)
if len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.num_image_writer_processes,
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
# sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
else:
# Create empty dataset or load existing saved episodes
# sanity_check_dataset_name(cfg.repo_id, cfg.policy)
dataset = LeRobotDataset.create(
cfg.repo_id,
cfg.fps,
root=cfg.root,
robot=None,
features=robot.features,
use_videos=cfg.video,
image_writer_processes=cfg.num_image_writer_processes,
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
# policy = None
# if not robot.is_connected:
# robot.connect()
listener, events = init_keyboard_listener()
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,
# 2. give times to the robot devices to connect and start synchronizing,
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", cfg.play_sounds)
print()
print(f"开始记录轨迹,共需要记录{cfg.num_episodes}\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")
# warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
# if has_method(robot, "teleop_safety_stop"):
# robot.teleop_safety_stop()
recorded_episodes = 0
while True:
if recorded_episodes >= cfg.num_episodes:
break
# if events["record_start"]:
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
record_episode(
robot=robot,
dataset=dataset,
events=events,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
fps=cfg.fps,
single_task=cfg.single_task,
)
# Execute a few seconds without recording to give time to manually reset the environment
# Current code logic doesn't allow to teleoperate during this time.
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
pprint("Reset the environment, stop recording")
# reset_environment(robot, events, cfg.reset_time_s, cfg.fps)
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
pprint("Re-record episode")
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
recorded_episodes += 1
if events["stop_recording"]:
break
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
log_say("Exiting", cfg.play_sounds)
return dataset
def replay(
robot: AgilexRobot,
cfg,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
actions = dataset.hf_dataset.select_columns("action")
# if not robot.is_connected:
# robot.connect()
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
action = actions[idx]["action"]
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / cfg.fps - dt_s)
dt_s = time.perf_counter() - start_episode_t
# log_control_info(robot, dt_s, fps=cfg.fps)
import argparse
def get_arguments():
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.fps = 30
args.resume = False
args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right"
args.root = "./data5"
args.episode = 0 # replay episode
args.num_image_writer_processes = 0
args.num_image_writer_threads_per_camera = 4
args.video = True
args.num_episodes = 100
args.episode_time_s = 30000
args.play_sounds = False
args.display_cameras = True
args.single_task = "move the bottle from the right to the scale right"
args.use_depth_image = False
args.use_base = False
args.push_to_hub = False
args.policy = None
# args.teleoperate = True
args.control_type = "record"
# args.control_type = "replay"
return args
# @parser.wrap()
def control_robot(cfg):
# Create the robot instance via the factory pattern
robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg)
if cfg.control_type == "record":
record(robot, cfg)
elif cfg.control_type == "replay":
replay(robot, cfg)
if __name__ == "__main__":
cfg = get_arguments()
control_robot(cfg)
# control_robot()
# Create the robot instance via the factory pattern
# robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# print(robot.features.items())
# print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
# record(robot, cfg)
# capture = robot.capture_observation()
# import torch
# torch.save(capture, "test.pt")
# action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094,
# 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]],
# device='cpu')
# robot.send_action(action.squeeze(0))
# print()

View File

@@ -1,17 +1,13 @@
import yaml
import cv2
import numpy as np
import collections
import dm_env
import argparse
from typing import Dict, List, Any, Optional
from collections import deque
import rospy
from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from nav_msgs.msg import Odometry
from rosrobot import Robot
from sensor_msgs.msg import JointState
from .rosrobot import Robot
import torch
import time
@@ -40,9 +36,12 @@ class AgilexRobot(Robot):
# print("can not get data from puppet topic")
# return None
if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0:
print("can not get data from puppet topic")
return None
# Check that the required arm data is available
required_arms = ['puppet_left', 'puppet_right']
for arm_name in required_arms:
if arm_name not in self.sync_arm_queues or len(self.sync_arm_queues[arm_name]) == 0:
print(f"cannot get data from {arm_name} topic")
return None
# Compute the minimum timestamp
timestamps = [
@@ -330,12 +329,18 @@ class AgilexRobot(Robot):
Returns:
The actual action that was sent (may be clipped if safety checks are implemented)
"""
# if not hasattr(self, 'puppet_arm_publishers'):
# # Initialize publishers on first call
# self._init_action_publishers()
# Default velocity and effort values
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125,
0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.010990142822265625, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.03296661376953125]
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945,
3.6527481079101562, -0.013187408447265625, -0.013187408447265625,
0.0, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.010990142822265625, -0.010990142822265625,
-0.03296661376953125, -0.03296661376953125]
# Convert tensor to numpy array if needed
if isinstance(action, torch.Tensor):

View File

@@ -8,29 +8,38 @@ from nav_msgs.msg import Odometry
import argparse
class Robot:
def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
class RobotConfig:
"""Configuration management for robot components"""
def __init__(self, config_file: str):
"""
Robot base class that handles common initialization logic
Initialize robot configuration from YAML file
Args:
config_file: Path to the YAML configuration file
args: Runtime arguments
config_file: Path to YAML configuration file
"""
self._load_config(config_file)
self._merge_runtime_args(args)
self._init_components()
self._init_data_structures()
self.init_ros()
self.init_features()
self.warmup()
def _load_config(self, config_file: str) -> None:
"""加载YAML配置文件"""
self.config = self._load_yaml(config_file)
self._validate_config()
def _load_yaml(self, config_file: str) -> Dict[str, Any]:
"""Load configuration from YAML file"""
with open(config_file, 'r') as f:
self.config = yaml.safe_load(f)
def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
"""合并运行时参数到配置"""
return yaml.safe_load(f)
def _validate_config(self) -> None:
"""Validate configuration completeness"""
required_sections = ['cameras', 'arm']
for section in required_sections:
if section not in self.config:
raise ValueError(f"Missing required config section: {section}")
def merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None:
"""
Merge runtime arguments into configuration
Args:
args: Runtime arguments from command line
"""
if args is None:
return
@@ -47,217 +56,56 @@ class Robot:
for key, value in runtime_params.items():
if value is not None:
self.config[key] = value
def get(self, key: str, default=None) -> Any:
"""Get configuration value with default fallback"""
return self.config.get(key, default)
def _init_components(self) -> None:
"""初始化核心组件"""
class RosAdapter:
"""Adapter for ROS communication"""
def __init__(self, config: RobotConfig):
"""
Initialize ROS adapter
Args:
config: Robot configuration
"""
self.config = config
self.bridge = CvBridge()
self.subscribers = {}
self.publishers = {}
self._validate_config()
def _validate_config(self) -> None:
"""验证配置完整性"""
required_sections = ['cameras', 'arm']
for section in required_sections:
if section not in self.config:
raise ValueError(f"Missing required config section: {section}")
def _init_data_structures(self) -> None:
"""初始化数据结构模板方法"""
# 相机数据
self.cameras = self.config.get('cameras', {})
self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras}
# Depth data
self.use_depth_image = self.config.get('use_depth_image', False)
if self.use_depth_image:
self.sync_depth_queues = {
name: deque(maxlen=2000)
for name, cam in self.cameras.items()
if 'depth_topic_name' in cam
}
def init_ros_node(self, node_name: str = None) -> None:
"""Initialize ROS node"""
if node_name is None:
node_name = self.config.get('ros_node_name', 'generic_robot_node')
rospy.init_node(node_name, anonymous=True)
# Arm data
self.arms = self.config.get('arm', {})
if self.config.get('control_type', '') != 'record':
# If not in record mode, only initialize the puppet-arm data queues
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name}
else:
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms}
# Robot base data
self.use_robot_base = self.config.get('use_robot_base', False)
if self.use_robot_base:
self.sync_base_queue = deque(maxlen=2000)
def init_ros(self) -> None:
"""初始化ROS订阅的模板方法"""
rospy.init_node(
f"{self.config.get('ros_node_name', 'generic_robot_node')}",
anonymous=True
def create_subscriber(self, topic: str, msg_type, callback, queue_size: int = 1000, tcp_nodelay: bool = True):
"""Create a ROS subscriber"""
subscriber = rospy.Subscriber(
topic,
msg_type,
callback,
queue_size=queue_size,
tcp_nodelay=tcp_nodelay
)
return subscriber
self._setup_camera_subscribers()
self._setup_arm_subscribers_publishers()
self._setup_base_subscriber()
self._log_ros_status()
def init_features(self):
"""
Automatically build the features structure from the YAML config
"""
self.features = {}
def create_publisher(self, topic: str, msg_type, queue_size: int = 10):
"""Create a ROS publisher"""
publisher = rospy.Publisher(
topic,
msg_type,
queue_size=queue_size
)
return publisher
# Initialize camera features
self._init_camera_features()
# Initialize arm features
self._init_state_features()
self._init_action_features()
# Initialize base features (if enabled)
if self.use_robot_base:
self._init_base_features()
import pprint
pprint.pprint(self.features, indent=4)
def _init_camera_features(self):
"""处理所有相机特征"""
for cam_name, cam_config in self.cameras.items():
# Regular RGB images
self.features[f"observation.images.{cam_name}"] = {
"dtype": "video" if self.config.get("video", False) else "image",
"shape": cam_config.get("rgb_shape", [480, 640, 3]),
"names": ["height", "width", "channel"],
# "video_info": {
# "video.fps": cam_config.get("fps", 30.0),
# "video.codec": cam_config.get("codec", "av1"),
# "video.pix_fmt": cam_config.get("pix_fmt", "yuv420p"),
# "video.is_depth_map": False,
# "has_audio": False
# }
}
if self.config.get("use_depth_image", False):
self.features[f"observation.images.depth_{cam_name}"] = {
"dtype": "uint16",
"shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1),
"names": ["height", "width"],
}
def _init_state_features(self):
state = self.config.get('state', {})
# State features
self.features["observation.state"] = {
"dtype": "float32",
"shape": (len(state.get('motors', "")),),
"names": {"motors": state.get('motors', "")}
}
if self.config.get('velocity'):
velocity = self.config.get('velocity', "")
self.features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(velocity.get('motors', "")),),
"names": {"motors": velocity.get('motors', "")}
}
if self.config.get('effort'):
effort = self.config.get('effort', "")
self.features["observation.effort"] = {
"dtype": "float32",
"shape": (len(effort.get('motors', "")),),
"names": {"motors": effort.get('motors', "")}
}
def _init_action_features(self):
action = self.config.get('action', {})
# Action features
self.features["action"] = {
"dtype": "float32",
"shape": (len(action.get('motors', "")),),
"names": {"motors": action.get('motors', "")}
}
def _init_base_features(self):
"""处理基座特征"""
self.features["observation.base_vel"] = {
"dtype": "float32",
"shape": (2,),
"names": ["linear_x", "angular_z"]
}
def _setup_camera_subscribers(self) -> None:
"""设置相机订阅者"""
for cam_name, cam_config in self.cameras.items():
if 'img_topic_name' in cam_config:
self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber(
cam_config['img_topic_name'],
Image,
self._make_camera_callback(cam_name, is_depth=False),
queue_size=1000,
tcp_nodelay=True
)
if self.use_depth_image and 'depth_topic_name' in cam_config:
self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber(
cam_config['depth_topic_name'],
Image,
self._make_camera_callback(cam_name, is_depth=True),
queue_size=1000,
tcp_nodelay=True
)
def _setup_arm_subscribers_publishers(self) -> None:
"""设置机械臂订阅者"""
# 当为record模式时,主从机械臂都需要订阅
# 否则只订阅从机械臂,但向主机械臂发布
if self.config.get('control_type', '') == 'record':
for arm_name, arm_config in self.arms.items():
if 'topic_name' in arm_config:
self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
arm_config['topic_name'],
JointState,
self._make_arm_callback(arm_name),
queue_size=1000,
tcp_nodelay=True
)
else:
for arm_name, arm_config in self.arms.items():
if 'puppet' in arm_name:
self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber(
arm_config['topic_name'],
JointState,
self._make_arm_callback(arm_name),
queue_size=1000,
tcp_nodelay=True
)
if 'master' in arm_name:
self.publishers[f"arm_{arm_name}"] = rospy.Publisher(
arm_config['topic_name'],
JointState,
queue_size=10
)
def _setup_base_subscriber(self) -> None:
"""设置基座订阅者"""
if self.use_robot_base and 'robot_base' in self.config:
self.subscribers['base'] = rospy.Subscriber(
self.config['robot_base']['topic_name'],
Odometry,
self.robot_base_callback,
queue_size=1000,
tcp_nodelay=True
)
def _log_ros_status(self) -> None:
"""记录ROS状态"""
def log_status(self) -> None:
"""Log ROS connection status"""
rospy.loginfo("\n=== ROS订阅状态 ===")
rospy.loginfo(f"已初始化节点: {rospy.get_name()}")
rospy.loginfo("活跃的订阅者:")
@@ -265,8 +113,74 @@ class Robot:
rospy.loginfo(f" - {topic}: {'活跃' if sub.impl else '未连接'}")
rospy.loginfo("=================")
class RobotSensors:
"""Management of robot sensors (cameras, depth sensors)"""
def __init__(self, config: RobotConfig, ros_adapter: RosAdapter):
"""
Initialize robot sensors
Args:
config: Robot configuration
ros_adapter: ROS communication adapter
"""
self.config = config
self.ros_adapter = ros_adapter
self.bridge = ros_adapter.bridge
# Camera data
self.cameras = config.get('cameras', {})
self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras}
# Depth data
self.use_depth_image = config.get('use_depth_image', False)
if self.use_depth_image:
self.sync_depth_queues = {
name: deque(maxlen=2000)
for name, cam in self.cameras.items()
if 'depth_topic_name' in cam
}
# Robot base data
self.use_robot_base = config.get('use_robot_base', False)
if self.use_robot_base:
self.sync_base_queue = deque(maxlen=2000)
def setup_subscribers(self) -> None:
"""Set up ROS subscribers for sensors"""
self._setup_camera_subscribers()
if self.use_robot_base:
self._setup_base_subscriber()
def _setup_camera_subscribers(self) -> None:
"""Set up camera subscribers"""
for cam_name, cam_config in self.cameras.items():
if 'img_topic_name' in cam_config:
self.ros_adapter.subscribers[f"camera_{cam_name}"] = self.ros_adapter.create_subscriber(
cam_config['img_topic_name'],
Image,
self._make_camera_callback(cam_name, is_depth=False)
)
if self.use_depth_image and 'depth_topic_name' in cam_config:
self.ros_adapter.subscribers[f"depth_{cam_name}"] = self.ros_adapter.create_subscriber(
cam_config['depth_topic_name'],
Image,
self._make_camera_callback(cam_name, is_depth=True)
)
def _setup_base_subscriber(self) -> None:
"""Set up base subscriber"""
if 'robot_base' in self.config.config:
self.ros_adapter.subscribers['base'] = self.ros_adapter.create_subscriber(
self.config.get('robot_base')['topic_name'],
Odometry,
self.robot_base_callback
)
def _make_camera_callback(self, cam_name: str, is_depth: bool = False):
"""生成相机回调函数工厂方法"""
"""Generate camera callback factory method"""
def callback(msg):
try:
target_queue = (
@@ -281,8 +195,105 @@ class Robot:
rospy.logerr(f"Camera {cam_name} callback error: {str(e)}")
return callback
def robot_base_callback(self, msg):
"""Base callback default implementation"""
if len(self.sync_base_queue) >= 2000:
self.sync_base_queue.popleft()
self.sync_base_queue.append(msg)
def init_features(self) -> Dict[str, Any]:
"""Initialize sensor features"""
features = {}
# Initialize camera features
self._init_camera_features(features)
# Initialize base features (if enabled)
if self.use_robot_base:
self._init_base_features(features)
return features
def _init_camera_features(self, features: Dict[str, Any]) -> None:
"""Process all camera features"""
for cam_name, cam_config in self.cameras.items():
# Regular images
features[f"observation.images.{cam_name}"] = {
"dtype": "video" if self.config.get("video", False) else "image",
"shape": cam_config.get("rgb_shape", [480, 640, 3]),
"names": ["height", "width", "channel"],
}
if self.config.get("use_depth_image", False):
features[f"observation.images.depth_{cam_name}"] = {
"dtype": "uint16",
"shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1),
"names": ["height", "width"],
}
def _init_base_features(self, features: Dict[str, Any]) -> None:
"""Process base features"""
features["observation.base_vel"] = {
"dtype": "float32",
"shape": (2,),
"names": ["linear_x", "angular_z"]
}
class RobotActuators:
"""Management of robot actuators (arms, base)"""
def __init__(self, config: RobotConfig, ros_adapter: RosAdapter):
"""
Initialize robot actuators
Args:
config: Robot configuration
ros_adapter: ROS communication adapter
"""
self.config = config
self.ros_adapter = ros_adapter
# Arm data
self.arms = config.get('arm', {})
if config.get('control_type', '') != 'record':
# If not in record mode, only initialize puppet arm queues
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name}
else:
self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms}
def setup_subscribers_publishers(self) -> None:
"""Set up ROS subscribers and publishers for actuators"""
self._setup_arm_subscribers_publishers()
def _setup_arm_subscribers_publishers(self) -> None:
"""Set up arm subscribers and publishers"""
# When in record mode, subscribe to both master and puppet arms
# Otherwise only subscribe to puppet arms, but publish to master arms
if self.config.get('control_type', '') == 'record':
for arm_name, arm_config in self.arms.items():
if 'topic_name' in arm_config:
self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber(
arm_config['topic_name'],
JointState,
self._make_arm_callback(arm_name)
)
else:
for arm_name, arm_config in self.arms.items():
if 'puppet' in arm_name:
self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber(
arm_config['topic_name'],
JointState,
self._make_arm_callback(arm_name)
)
if 'master' in arm_name:
self.ros_adapter.publishers[f"arm_{arm_name}"] = self.ros_adapter.create_publisher(
arm_config['topic_name'],
JointState
)
def _make_arm_callback(self, arm_name: str):
"""生成机械臂回调函数工厂方法"""
"""Generate arm callback factory method"""
def callback(msg):
try:
if len(self.sync_arm_queues[arm_name]) >= 2000:
@@ -292,17 +303,74 @@ class Robot:
rospy.logerr(f"Arm {arm_name} callback error: {str(e)}")
return callback
def robot_base_callback(self, msg):
"""基座回调默认实现"""
if len(self.sync_base_queue) >= 2000:
self.sync_base_queue.popleft()
self.sync_base_queue.append(msg)
def init_features(self) -> Dict[str, Any]:
"""Initialize actuator features"""
features = {}
# Initialize arm features
self._init_state_features(features)
self._init_action_features(features)
return features
def _init_state_features(self, features: Dict[str, Any]) -> None:
"""Initialize state features"""
state = self.config.get('state', {})
# State features
features["observation.state"] = {
"dtype": "float32",
"shape": (len(state.get('motors', "")),),
"names": {"motors": state.get('motors', "")}
}
def warmup(self, timeout: float = 10.0) -> bool:
"""Wait until all data queues have at least 20 messages.
if self.config.get('velocity'):
velocity = self.config.get('velocity', "")
features["observation.velocity"] = {
"dtype": "float32",
"shape": (len(velocity.get('motors', "")),),
"names": {"motors": velocity.get('motors', "")}
}
if self.config.get('effort'):
effort = self.config.get('effort', "")
features["observation.effort"] = {
"dtype": "float32",
"shape": (len(effort.get('motors', "")),),
"names": {"motors": effort.get('motors', "")}
}
def _init_action_features(self, features: Dict[str, Any]) -> None:
"""Initialize action features"""
action = self.config.get('action', {})
features["action"] = {
"dtype": "float32",
"shape": (len(action.get('motors', "")),),
"names": {"motors": action.get('motors', "")}
}
class RobotDataManager:
"""Management of robot data collection and synchronization"""
def __init__(self, config: RobotConfig, sensors: RobotSensors, actuators: RobotActuators):
"""
Initialize robot data manager
Args:
timeout: Maximum time to wait in seconds before giving up
config: Robot configuration
sensors: Robot sensors component
actuators: Robot actuators component
"""
self.config = config
self.sensors = sensors
self.actuators = actuators
def warmup(self, timeout: float = 10.0) -> bool:
"""
Wait until all data queues have sufficient messages
Args:
timeout: Maximum time to wait in seconds
Returns:
bool: True if warmup succeeded, False if timed out
@@ -323,31 +391,24 @@ class Robot:
all_ready = True
# Check camera image queues
for cam_name in self.cameras:
if len(self.sync_img_queues[cam_name]) < 50:
rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sync_img_queues[cam_name])}/50)")
for cam_name in self.sensors.cameras:
if len(self.sensors.sync_img_queues[cam_name]) < 50:
rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sensors.sync_img_queues[cam_name])}/50)")
all_ready = False
break
# Check depth queues if enabled
if self.use_depth_image:
for cam_name in self.sync_depth_queues:
if len(self.sync_depth_queues[cam_name]) < 50:
rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sync_depth_queues[cam_name])}/50)")
if self.sensors.use_depth_image:
for cam_name in self.sensors.sync_depth_queues:
if len(self.sensors.sync_depth_queues[cam_name]) < 50:
rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sensors.sync_depth_queues[cam_name])}/50)")
all_ready = False
break
# # Check arm queues
# for arm_name in self.arms:
# if len(self.sync_arm_queues[arm_name]) < 20:
# rospy.loginfo(f"Waiting for arm {arm_name} (current: {len(self.sync_arm_queues[arm_name])}/20)")
# all_ready = False
# break
# Check base queue if enabled
if self.use_robot_base:
if len(self.sync_base_queue) < 20:
rospy.loginfo(f"Waiting for base (current: {len(self.sync_base_queue)}/20)")
if self.sensors.use_robot_base:
if len(self.sensors.sync_base_queue) < 20:
rospy.loginfo(f"Waiting for base (current: {len(self.sensors.sync_base_queue)}/20)")
all_ready = False
# If all queues are ready, return success
@@ -357,16 +418,4 @@ class Robot:
rate.sleep()
return False
def get_frame(self) -> Optional[Dict[str, Any]]:
"""获取同步帧数据的模板方法"""
raise NotImplementedError("Subclasses must implement get_frame()")
def process(self) -> tuple:
"""主处理循环的模板方法"""
raise NotImplementedError("Subclasses must implement process()")
return False

View File

@@ -0,0 +1,136 @@
import yaml
from typing import Dict, Any, Optional, List
import argparse
from .robot_components import RobotConfig, RosAdapter, RobotSensors, RobotActuators, RobotDataManager
class Robot:
def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None):
"""
Robot base class that handles common initialization logic
Args:
config_file: Path to the YAML configuration file
args: Runtime arguments
"""
# Initialize components
self.config = RobotConfig(config_file)
self.config.merge_runtime_args(args)
self.ros_adapter = RosAdapter(self.config)
self.sensors = RobotSensors(self.config, self.ros_adapter)
self.actuators = RobotActuators(self.config, self.ros_adapter)
self.data_manager = RobotDataManager(self.config, self.sensors, self.actuators)
# Initialize ROS and features
self.init_ros()
self.init_features()
self.warmup()
def get(self, key: str, default=None) -> Any:
"""获取配置值"""
return self.config.get(key, default)
@property
def bridge(self):
"""获取CV桥接器"""
return self.ros_adapter.bridge
@property
def subscribers(self):
"""获取订阅者"""
return self.ros_adapter.subscribers
@property
def publishers(self):
"""获取发布者"""
return self.ros_adapter.publishers
@property
def cameras(self):
"""获取相机配置"""
return self.sensors.cameras
@property
def arms(self):
"""获取机械臂配置"""
return self.actuators.arms
@property
def sync_img_queues(self):
"""获取图像队列"""
return self.sensors.sync_img_queues
@property
def sync_depth_queues(self):
"""获取深度图像队列"""
return self.sensors.sync_depth_queues if hasattr(self.sensors, 'sync_depth_queues') else {}
@property
def sync_arm_queues(self):
"""获取机械臂队列"""
return self.actuators.sync_arm_queues
@property
def sync_base_queue(self):
"""获取基座队列"""
return self.sensors.sync_base_queue if hasattr(self.sensors, 'sync_base_queue') else None
@property
def use_depth_image(self):
"""是否使用深度图像"""
return self.sensors.use_depth_image
@property
def use_robot_base(self):
"""是否使用机器人基座"""
return self.sensors.use_robot_base
def init_ros(self) -> None:
"""初始化ROS订阅的模板方法"""
self.ros_adapter.init_ros_node()
# 设置传感器和执行器的订阅者和发布者
self.sensors.setup_subscribers()
self.actuators.setup_subscribers_publishers()
# Log ROS status
self.ros_adapter.log_status()
def init_features(self):
"""
Automatically build the features structure from the YAML config
"""
# Merge sensor and actuator features
self.features = {}
self.features.update(self.sensors.init_features())
self.features.update(self.actuators.init_features())
import pprint
pprint.pprint(self.features, indent=4)
def warmup(self, timeout: float = 10.0) -> bool:
"""Wait until all data queues have at least 20 messages.
Args:
timeout: Maximum time to wait in seconds before giving up
Returns:
bool: True if warmup succeeded, False if timed out
"""
return self.data_manager.warmup(timeout)
def get_frame(self) -> Optional[Dict[str, Any]]:
"""获取同步帧数据的模板方法"""
raise NotImplementedError("Subclasses must implement get_frame()")
def process(self) -> tuple:
"""主处理循环的模板方法"""
raise NotImplementedError("Subclasses must implement process()")

View File

@@ -0,0 +1,59 @@
import yaml
import argparse
from typing import Dict, List, Any, Optional, Type
from .rosrobot import Robot
from .agilex_robot import AgilexRobot
class RobotFactory:
"""Factory for creating robot instances based on configuration"""
# Registry of available robot types
_registry = {}
@classmethod
def register(cls, robot_type: str, robot_class: Type[Robot]) -> None:
"""
Register a new robot type
Args:
robot_type: Robot type identifier
robot_class: Robot class implementation
"""
cls._registry[robot_type] = robot_class
@classmethod
def create(cls, config_file: str, args: Optional[argparse.Namespace] = None) -> Robot:
"""
Create the appropriate robot instance based on the configuration file
Args:
config_file: Path to the configuration file
args: Runtime arguments
Returns:
Robot: The created robot instance
Raises:
ValueError: If the specified robot type is not supported
"""
with open(config_file, 'r') as f:
config = yaml.safe_load(f)
robot_type = config.get('robot_type', 'agilex')
# If the registry is empty, register the default robot types
if not cls._registry:
cls.register('agilex', AgilexRobot)
cls.register('aloha_agilex', AgilexRobot) # alias support
# Look up the robot class in the registry
if robot_type in cls._registry:
return cls._registry[robot_type](config_file, args)
else:
raise ValueError(f"Unsupported robot type: {robot_type}. Available types: {list(cls._registry.keys())}")
# Register the available robot types
RobotFactory.register('agilex', AgilexRobot)
RobotFactory.register('aloha_agilex', AgilexRobot) # alias support
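A hedged usage sketch of the factory (the config location is a placeholder, not part of this commit):

from common.rosrobot_factory import RobotFactory

# Optionally register a custom implementation before creating an instance:
# RobotFactory.register('my_robot', MyRobot)
robot = RobotFactory.create(config_file="configs/agilex.yaml", args=None)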

View File

@@ -0,0 +1,146 @@
robot_type: aloha_agilex
ros_node_name: record_episodes
cameras:
cam_front:
img_topic_name: /camera_f/color/image_raw
depth_topic_name: /camera_f/depth/image_raw
width: 480
height: 640
rgb_shape: [480, 640, 3]
cam_left:
img_topic_name: /camera_l/color/image_raw
depth_topic_name: /camera_l/depth/image_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
cam_right:
img_topic_name: /camera_r/color/image_raw
depth_topic_name: /camera_r/depth/image_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
cam_high:
img_topic_name: /camera/color/image_raw
depth_topic_name: /camera/depth/image_rect_raw
rgb_shape: [480, 640, 3]
width: 480
height: 640
arm:
master_left:
topic_name: /master/joint_left
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none"
]
master_right:
topic_name: /master/joint_right
motors: [
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
puppet_left:
topic_name: /puppet/joint_left
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none"
]
puppet_right:
topic_name: /puppet/joint_right
motors: [
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
# follow the joint name in ros
state:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
velocity:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
effort:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
action:
motors: [
"left_joint0",
"left_joint1",
"left_joint2",
"left_joint3",
"left_joint4",
"left_joint5",
"left_none",
"right_joint0",
"right_joint1",
"right_joint2",
"right_joint3",
"right_joint4",
"right_joint5",
"right_none"
]
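A small sketch (module path and file location are assumptions) of how RobotConfig consumes this file:

from common.robot_components import RobotConfig  # import path is an assumption

cfg = RobotConfig("configs/agilex.yaml")    # loads the YAML; validation requires 'cameras' and 'arm'
print(cfg.get("robot_type"))                # aloha_agilex
print(list(cfg.get("cameras", {}).keys()))  # ['cam_front', 'cam_left', 'cam_right', 'cam_high']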

lerobot_aloha/inference.py (new file)

@@ -0,0 +1,769 @@
#!/home/lin/software/miniconda3/envs/aloha/bin/python
# -*- coding: UTF-8 -*-
"""
#!/usr/bin/python3
"""
import torch
import numpy as np
import os
import pickle
import argparse
from einops import rearrange
import collections
from collections import deque
import rospy
from std_msgs.msg import Header
from geometry_msgs.msg import Twist
from sensor_msgs.msg import JointState, Image
from nav_msgs.msg import Odometry
from cv_bridge import CvBridge
import time
import threading
import math
import sys
sys.path.append("./")
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']}
inference_thread = None
inference_lock = threading.Lock()
inference_actions = None
inference_timestep = None
def actions_interpolation(args, pre_action, actions, stats):
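# Densify the predicted action chunk: linearly interpolate between successive targets
# so each joint moves at most ~arm_steps_length per step. Indices 6 and 13 (grippers)
# are ignored when locating the largest jump; the result is re-normalized with the
# qpos statistics before being returned.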
steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0)
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
result = [pre_action]
post_action = post_process(actions[0])
# print("pre_action:", pre_action[7:])
# print("actions_interpolation1:", post_action[:, 7:])
max_diff_index = 0
max_diff = -1
for i in range(post_action.shape[0]):
diff = 0
for j in range(pre_action.shape[0]):
if j == 6 or j == 13:
continue
diff += math.fabs(pre_action[j] - post_action[i][j])
if diff > max_diff:
max_diff = diff
max_diff_index = i
for i in range(max_diff_index, post_action.shape[0]):
step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])])
inter = np.linspace(result[-1], post_action[i], step+2)
result.extend(inter[1:])
while len(result) < args.chunk_size+1:
result.append(result[-1])
result = np.array(result)[1:args.chunk_size+1]
# print("actions_interpolation2:", result.shape, result[:, 7:])
result = pre_process(result)
result = result[np.newaxis, :]
return result
def get_model_config(args):
# Fixing the random seed makes runs reproducible given the same initial conditions.
set_seed(1)
# If using the ACT policy
# fixed parameters
if args.policy_class == 'ACT':
policy_config = {'lr': args.lr,
'lr_backbone': args.lr_backbone,
'backbone': args.backbone,
'masks': args.masks,
'weight_decay': args.weight_decay,
'dilation': args.dilation,
'position_embedding': args.position_embedding,
'loss_function': args.loss_function,
'chunk_size': args.chunk_size, # number of queried actions
'camera_names': task_config['camera_names'],
'use_depth_image': args.use_depth_image,
'use_robot_base': args.use_robot_base,
'kl_weight': args.kl_weight, # KL divergence weight
'hidden_dim': args.hidden_dim, # hidden layer dimension
'dim_feedforward': args.dim_feedforward,
'enc_layers': args.enc_layers,
'dec_layers': args.dec_layers,
'nheads': args.nheads,
'dropout': args.dropout,
'pre_norm': args.pre_norm
}
elif args.policy_class == 'CNNMLP':
policy_config = {'lr': args.lr,
'lr_backbone': args.lr_backbone,
'backbone': args.backbone,
'masks': args.masks,
'weight_decay': args.weight_decay,
'dilation': args.dilation,
'position_embedding': args.position_embedding,
'loss_function': args.loss_function,
'chunk_size': 1, # single queried action
'camera_names': task_config['camera_names'],
'use_depth_image': args.use_depth_image,
'use_robot_base': args.use_robot_base
}
elif args.policy_class == 'Diffusion':
policy_config = {'lr': args.lr,
'lr_backbone': args.lr_backbone,
'backbone': args.backbone,
'masks': args.masks,
'weight_decay': args.weight_decay,
'dilation': args.dilation,
'position_embedding': args.position_embedding,
'loss_function': args.loss_function,
'chunk_size': args.chunk_size, # number of queried actions
'camera_names': task_config['camera_names'],
'use_depth_image': args.use_depth_image,
'use_robot_base': args.use_robot_base,
'observation_horizon': args.observation_horizon,
'action_horizon': args.action_horizon,
'num_inference_timesteps': args.num_inference_timesteps,
'ema_power': args.ema_power
}
else:
raise NotImplementedError
config = {
'ckpt_dir': args.ckpt_dir,
'ckpt_name': args.ckpt_name,
'ckpt_stats_name': args.ckpt_stats_name,
'episode_len': args.max_publish_step,
'state_dim': args.state_dim,
'policy_class': args.policy_class,
'policy_config': policy_config,
'temporal_agg': args.temporal_agg,
'camera_names': task_config['camera_names'],
}
return config
def make_policy(policy_class, policy_config):
if policy_class == 'ACT':
policy = ACTPolicy(policy_config)
elif policy_class == 'CNNMLP':
policy = CNNMLPPolicy(policy_config)
elif policy_class == 'Diffusion':
policy = DiffusionPolicy(policy_config)
else:
raise NotImplementedError
return policy
def get_image(observation, camera_names):
curr_images = []
for cam_name in camera_names:
curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w')
curr_images.append(curr_image)
curr_image = np.stack(curr_images, axis=0)
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
return curr_image
def get_depth_image(observation, camera_names):
curr_images = []
for cam_name in camera_names:
curr_images.append(observation['images_depth'][cam_name])
curr_image = np.stack(curr_images, axis=0)
curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0)
return curr_image
def inference_process(args, config, ros_operator, policy, stats, t, pre_action):
global inference_lock
global inference_actions
global inference_timestep
print_flag = True
pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"]
rate = rospy.Rate(args.publish_rate)
while not rospy.is_shutdown():
result = ros_operator.get_frame()
if not result:
if print_flag:
print("syn fail")
print_flag = False
rate.sleep()
continue
print_flag = True
(img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base) = result
obs = collections.OrderedDict()
image_dict = dict()
image_dict[config['camera_names'][0]] = img_front
image_dict[config['camera_names'][1]] = img_left
image_dict[config['camera_names'][2]] = img_right
obs['images'] = image_dict
if args.use_depth_image:
image_depth_dict = dict()
image_depth_dict[config['camera_names'][0]] = img_front_depth
image_depth_dict[config['camera_names'][1]] = img_left_depth
image_depth_dict[config['camera_names'][2]] = img_right_depth
obs['images_depth'] = image_depth_dict
obs['qpos'] = np.concatenate(
(np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0)
obs['qvel'] = np.concatenate(
(np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0)
obs['effort'] = np.concatenate(
(np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0)
if args.use_robot_base:
obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z]
obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0)
else:
obs['base_vel'] = [0.0, 0.0]
# qpos_numpy = np.array(obs['qpos'])
# Normalize qpos and move it to CUDA
qpos = pre_pos_process(obs['qpos'])
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
# Grab the current camera images into curr_image
curr_image = get_image(obs, config['camera_names'])
curr_depth_image = None
if args.use_depth_image:
curr_depth_image = get_depth_image(obs, config['camera_names'])
start_time = time.time()
all_actions = policy(curr_image, curr_depth_image, qpos)
end_time = time.time()
print("model cost time: ", end_time -start_time)
inference_lock.acquire()
inference_actions = all_actions.cpu().detach().numpy()
if pre_action is None:
pre_action = obs['qpos']
# print("obs['qpos']:", obs['qpos'][7:])
if args.use_actions_interpolation:
inference_actions = actions_interpolation(args, pre_action, inference_actions, stats)
inference_timestep = t
inference_lock.release()
break
def model_inference(args, config, ros_operator, save_episode=True):
global inference_lock
global inference_actions
global inference_timestep
global inference_thread
set_seed(1000)
# 1. Build the policy model (inherits from nn.Module)
policy = make_policy(config['policy_class'], config['policy_config'])
# print("model structure\n", policy.model)
# 2. Load the model weights
ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name'])
state_dict = torch.load(ckpt_path)
new_state_dict = {}
for key, value in state_dict.items():
if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]:
continue
if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]:
continue
new_state_dict[key] = value
loading_status = policy.deserialize(new_state_dict)
if not loading_status:
print("ckpt path not exist")
return False
# 3. Move the model to CUDA and switch to eval mode
policy.cuda()
policy.eval()
# 4. Load the normalization statistics
stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name'])
# Loads action_mean, action_std, qpos_mean, qpos_std (14-dim)
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
# Define the data pre- and post-processing functions
pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
max_publish_step = config['episode_len']
chunk_size = config['policy_config']['chunk_size']
# Publish the initial poses
left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
ros_operator.puppet_arm_publish_continuous(left0, right0)
input("Enter any key to continue :")
ros_operator.puppet_arm_publish_continuous(left1, right1)
action = None
# Inference
with torch.inference_mode():
while not rospy.is_shutdown():
# Step count within the episode
t = 0
max_t = 0
rate = rospy.Rate(args.publish_rate)
if config['temporal_agg']:
all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']])
while t < max_publish_step and not rospy.is_shutdown():
# start_time = time.time()
# query policy
if config['policy_class'] == "ACT":
if t >= max_t:
pre_action = action
inference_thread = threading.Thread(target=inference_process,
args=(args, config, ros_operator,
policy, stats, t, pre_action))
inference_thread.start()
inference_thread.join()
inference_lock.acquire()
if inference_actions is not None:
inference_thread = None
all_actions = inference_actions
inference_actions = None
max_t = t + args.pos_lookahead_step
if config['temporal_agg']:
all_time_actions[[t], t:t + chunk_size] = all_actions
inference_lock.release()
if config['temporal_agg']:
actions_for_curr_step = all_time_actions[:, t]
actions_populated = np.all(actions_for_curr_step != 0, axis=1)
actions_for_curr_step = actions_for_curr_step[actions_populated]
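# Temporal ensembling (ACT): average every chunked prediction made for this timestep,
# with exponentially decaying weights exp(-k * i) so earlier predictions count more.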
k = 0.01
exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
exp_weights = exp_weights / exp_weights.sum()
exp_weights = exp_weights[:, np.newaxis]
raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True)
else:
if args.pos_lookahead_step != 0:
raw_action = all_actions[:, t % args.pos_lookahead_step]
else:
raw_action = all_actions[:, t % chunk_size]
else:
raise NotImplementedError
action = post_process(raw_action[0])
left_action = action[:7] # first 7 dims belong to the left arm
right_action = action[7:14]
ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread
if args.use_robot_base:
vel_action = action[14:16]
ros_operator.robot_base_publish(vel_action)
t += 1
# end_time = time.time()
# print("publish: ", t)
# print("time:", end_time - start_time)
# print("left_action:", left_action)
# print("right_action:", right_action)
rate.sleep()
class RosOperator:
def __init__(self, args):
self.robot_base_deque = None
self.puppet_arm_right_deque = None
self.puppet_arm_left_deque = None
self.img_front_deque = None
self.img_right_deque = None
self.img_left_deque = None
self.img_front_depth_deque = None
self.img_right_depth_deque = None
self.img_left_depth_deque = None
self.bridge = None
self.puppet_arm_left_publisher = None
self.puppet_arm_right_publisher = None
self.robot_base_publisher = None
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_lock = None
self.args = args
self.ctrl_state = False
self.ctrl_state_lock = threading.Lock()
self.init()
self.init_ros()
def init(self):
self.bridge = CvBridge()
self.img_left_deque = deque()
self.img_right_deque = deque()
self.img_front_deque = deque()
self.img_left_depth_deque = deque()
self.img_right_depth_deque = deque()
self.img_front_depth_deque = deque()
self.puppet_arm_left_deque = deque()
self.puppet_arm_right_deque = deque()
self.robot_base_deque = deque()
self.puppet_arm_publish_lock = threading.Lock()
self.puppet_arm_publish_lock.acquire()
def puppet_arm_publish(self, left, right):
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # set the timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # set the joint names
joint_state_msg.position = left
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right
self.puppet_arm_right_publisher.publish(joint_state_msg)
def robot_base_publish(self, vel):
vel_msg = Twist()
vel_msg.linear.x = vel[0]
vel_msg.linear.y = 0
vel_msg.linear.z = 0
vel_msg.angular.x = 0
vel_msg.angular.y = 0
vel_msg.angular.z = vel[1]
self.robot_base_publisher.publish(vel_msg)
def puppet_arm_publish_continuous(self, left, right):
rate = rospy.Rate(self.args.publish_rate)
left_arm = None
right_arm = None
while not rospy.is_shutdown():
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
break
left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
flag = True
step = 0
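# Step each joint toward its target by at most arm_steps_length[i] per tick;
# a joint within one step snaps to the target, and the loop exits once every
# joint has converged (flag stays False).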
while flag and not rospy.is_shutdown():
if self.puppet_arm_publish_lock.acquire(False):
return
left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
flag = False
for i in range(len(left)):
if left_diff[i] < self.args.arm_steps_length[i]:
left_arm[i] = left[i]
else:
left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
flag = True
for i in range(len(right)):
if right_diff[i] < self.args.arm_steps_length[i]:
right_arm[i] = right[i]
else:
right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
flag = True
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # set timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # joint names
joint_state_msg.position = left_arm
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = right_arm
self.puppet_arm_right_publisher.publish(joint_state_msg)
step += 1
print("puppet_arm_publish_continuous:", step)
rate.sleep()
def puppet_arm_publish_linear(self, left, right):
num_step = 100
rate = rospy.Rate(200)
left_arm = None
right_arm = None
while not rospy.is_shutdown():
if len(self.puppet_arm_left_deque) != 0:
left_arm = list(self.puppet_arm_left_deque[-1].position)
if len(self.puppet_arm_right_deque) != 0:
right_arm = list(self.puppet_arm_right_deque[-1].position)
if left_arm is None or right_arm is None:
rate.sleep()
continue
else:
break
traj_left_list = np.linspace(left_arm, left, num_step)
traj_right_list = np.linspace(right_arm, right, num_step)
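# The last element (presumably the gripper) is snapped straight to the target
# at every waypoint instead of being interpolated.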
for i in range(len(traj_left_list)):
traj_left = traj_left_list[i]
traj_right = traj_right_list[i]
traj_left[-1] = left[-1]
traj_right[-1] = right[-1]
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.header.stamp = rospy.Time.now() # set timestamp
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # joint names
joint_state_msg.position = traj_left
self.puppet_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = traj_right
self.puppet_arm_right_publisher.publish(joint_state_msg)
rate.sleep()
def puppet_arm_publish_continuous_thread(self, left, right):
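# Restart helper: releasing the lock asks the currently running publisher
# thread to exit (see puppet_arm_publish_continuous), after which the lock is
# re-acquired non-blocking and a fresh thread is started with the new target.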
if self.puppet_arm_publish_thread is not None:
self.puppet_arm_publish_lock.release()
self.puppet_arm_publish_thread.join()
self.puppet_arm_publish_lock.acquire(False)
self.puppet_arm_publish_thread = None
self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
self.puppet_arm_publish_thread.start()
def get_frame(self):
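# Synchronization strategy: take the oldest of the newest timestamps across
# all required streams as frame_time, bail out if any stream has not reached
# it yet, then pop each deque forward until its head is aligned with frame_time.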
if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
(self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)):
return False
if self.args.use_depth_image:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(),
self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()])
else:
frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()])
if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time:
return False
if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time:
return False
if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time):
return False
if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time):
return False
while self.img_left_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_deque.popleft()
img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough')
while self.img_right_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_deque.popleft()
img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough')
while self.img_front_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_deque.popleft()
img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough')
while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_left_deque.popleft()
puppet_arm_left = self.puppet_arm_left_deque.popleft()
while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time:
self.puppet_arm_right_deque.popleft()
puppet_arm_right = self.puppet_arm_right_deque.popleft()
img_left_depth = None
if self.args.use_depth_image:
while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_left_depth_deque.popleft()
img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough')
img_right_depth = None
if self.args.use_depth_image:
while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_right_depth_deque.popleft()
img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough')
img_front_depth = None
if self.args.use_depth_image:
while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time:
self.img_front_depth_deque.popleft()
img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough')
robot_base = None
if self.args.use_robot_base:
while self.robot_base_deque[0].header.stamp.to_sec() < frame_time:
self.robot_base_deque.popleft()
robot_base = self.robot_base_deque.popleft()
return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth,
puppet_arm_left, puppet_arm_right, robot_base)
def img_left_callback(self, msg):
if len(self.img_left_deque) >= 2000:
self.img_left_deque.popleft()
self.img_left_deque.append(msg)
def img_right_callback(self, msg):
if len(self.img_right_deque) >= 2000:
self.img_right_deque.popleft()
self.img_right_deque.append(msg)
def img_front_callback(self, msg):
if len(self.img_front_deque) >= 2000:
self.img_front_deque.popleft()
self.img_front_deque.append(msg)
def img_left_depth_callback(self, msg):
if len(self.img_left_depth_deque) >= 2000:
self.img_left_depth_deque.popleft()
self.img_left_depth_deque.append(msg)
def img_right_depth_callback(self, msg):
if len(self.img_right_depth_deque) >= 2000:
self.img_right_depth_deque.popleft()
self.img_right_depth_deque.append(msg)
def img_front_depth_callback(self, msg):
if len(self.img_front_depth_deque) >= 2000:
self.img_front_depth_deque.popleft()
self.img_front_depth_deque.append(msg)
def puppet_arm_left_callback(self, msg):
if len(self.puppet_arm_left_deque) >= 2000:
self.puppet_arm_left_deque.popleft()
self.puppet_arm_left_deque.append(msg)
def puppet_arm_right_callback(self, msg):
if len(self.puppet_arm_right_deque) >= 2000:
self.puppet_arm_right_deque.popleft()
self.puppet_arm_right_deque.append(msg)
def robot_base_callback(self, msg):
if len(self.robot_base_deque) >= 2000:
self.robot_base_deque.popleft()
self.robot_base_deque.append(msg)
def ctrl_callback(self, msg):
self.ctrl_state_lock.acquire()
self.ctrl_state = msg.data
self.ctrl_state_lock.release()
def get_ctrl_state(self):
self.ctrl_state_lock.acquire()
state = self.ctrl_state
self.ctrl_state_lock.release()
return state
def init_ros(self):
rospy.init_node('joint_state_publisher', anonymous=True)
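# Large queue_size keeps message bursts from being dropped; tcp_nodelay
# disables Nagle batching on the subscriber links for lower latency.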
rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True)
if self.args.use_depth_image:
rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True)
rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True)
self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10)
self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10)
self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10)
def get_arguments():
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False)
parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False)
parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False)
parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False)
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False)
parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False)
parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False)
parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False)
parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False)
parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False)
parser.add_argument('--dilation', action='store_true',
help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False)
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features", required=False)
parser.add_argument('--masks', action='store_true',
help="Train segmentation head if the flag is provided")
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False)
parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False)
parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False)
parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False)
parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False)
parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False)
parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False)
parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False)
parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False)
parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False)
parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False)
parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False)
parser.add_argument('--pre_norm', action='store_true', required=False)
parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic',
default='/camera_f/color/image_raw', required=False)
parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic',
default='/camera_l/color/image_raw', required=False)
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
default='/camera_r/color/image_raw', required=False)
parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
default='/camera_f/depth/image_raw', required=False)
parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
default='/camera_l/depth/image_raw', required=False)
parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
default='/camera_r/depth/image_raw', required=False)
parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
default='/master/joint_left', required=False)
parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
default='/master/joint_right', required=False)
parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
default='/puppet/joint_left', required=False)
parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
default='/puppet/joint_right', required=False)
parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
default='/odom_raw', required=False)
parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_cmd_topic',
default='/cmd_vel', required=False)
parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
default=False, required=False)
parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate',
default=40, required=False)
parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step',
default=0, required=False)
parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size',
default=32, required=False)
parser.add_argument('--arm_steps_length', action='store', type=float, nargs='+', help='arm_steps_length',
default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False)
parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation',
default=False, required=False)
parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image',
default=False, required=False)
# for Diffusion
parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False)
parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False)
parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False)
parser.add_argument('--ema_power', action='store', type=float, help='ema_power', default=0.75, required=False)
args = parser.parse_args()
return args
def main():
args = get_arguments()
ros_operator = RosOperator(args)
config = get_model_config(args)
model_inference(args, config, ros_operator, save_episode=True)
if __name__ == '__main__':
main()
# python act/inference.py --ckpt_dir ~/train0314/

View File

@@ -0,0 +1,33 @@
import pandas as pd
def read_and_print_parquet_row(file_path, row_index=0):
"""
Read a Parquet file and print the data of the specified row.
Args:
file_path (str): path to the Parquet file
row_index (int): index of the row to print (defaults to row 0)
"""
try:
# Read the Parquet file
df = pd.read_parquet(file_path)
# Check that the row index is valid
if row_index >= len(df):
print(f"Error: row index {row_index} is out of range (file has {len(df)} rows)")
return
# Print the requested row
print(f"File: {file_path}")
print(f"Row {row_index}:\n{'-'*30}")
print(df.iloc[row_index])
except FileNotFoundError:
print(f"Error: file {file_path} does not exist")
except Exception as e:
print(f"Read failed: {str(e)}")
# Example usage
if __name__ == "__main__":
# Replace with your own Parquet file path
read_and_print_parquet_row("/home/jgl20/LYT/work/data/data/chunk-000/episode_000000.parquet", row_index=0) # print row 0

View File

@@ -0,0 +1,112 @@
#coding=utf-8
import os
import numpy as np
import cv2
import h5py
import argparse
import rospy
from cv_bridge import CvBridge
from std_msgs.msg import Header
from sensor_msgs.msg import Image, JointState
from geometry_msgs.msg import Twist
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
def main(args):
rospy.init_node("replay_node")
bridge = CvBridge()
# img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
# img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
# img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
# puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
# puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)
# robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
# dataset_dir = args.dataset_dir
# episode_idx = args.episode_idx
# task_name = args.task_name
# dataset_name = f'episode_{episode_idx}'
dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
actions = dataset.hf_dataset.select_columns("action")
velocitys = dataset.hf_dataset.select_columns("observation.velocity")
efforts = dataset.hf_dataset.select_columns("observation.effort")
origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279]
origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
joint_state_msg = JointState()
joint_state_msg.header = Header()
joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # joint names
twist_msg = Twist()
# qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)
last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
rate = rospy.Rate(50)
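# Upsample the recorded trajectory: each consecutive pair of dataset frames is
# expanded into 5 linspace waypoints and published at 50 Hz for smoother playback.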
for idx in range(len(actions)):
action = actions[idx]['action'].detach().cpu().numpy()
velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy()
effort = efforts[idx]['observation.effort'].detach().cpu().numpy()
if(rospy.is_shutdown()):
break
new_actions = np.linspace(last_action, action, 5) # interpolate between consecutive frames
new_velocitys = np.linspace(last_velocity, velocity, 5) # interpolate
new_efforts = np.linspace(last_effort, effort, 5) # interpolate
last_action = action
last_velocity = velocity
last_effort = effort
# Publish the interpolated velocity/effort alongside each interpolated position
for act, vel, eff in zip(new_actions, new_velocitys, new_efforts):
print(np.round(act[:7], 4))
cur_timestamp = rospy.Time.now() # set timestamp
joint_state_msg.header.stamp = cur_timestamp
joint_state_msg.position = act[:7]
joint_state_msg.velocity = vel[:7]
joint_state_msg.effort = eff[:7]
master_arm_left_publisher.publish(joint_state_msg)
joint_state_msg.position = act[7:]
joint_state_msg.velocity = vel[7:]
joint_state_msg.effort = eff[7:]
master_arm_right_publisher.publish(joint_state_msg)
if(rospy.is_shutdown()):
break
rate.sleep()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
# default='/master/joint_left', required=False)
# parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
# default='/master/joint_right', required=False)
args = parser.parse_args()
args.repo_id = "tangger/test"
args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
args.episode = 1 # replay episode
args.master_arm_left_topic = "/master/joint_left"
args.master_arm_right_topic = "/master/joint_right"
args.fps = 30
main(args)
# python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0

70
lerobot_aloha/test.py Normal file
View File

@@ -0,0 +1,70 @@
from lerobot.common.policies.act.modeling_act import ACTPolicy
from lerobot.common.robot_devices.utils import busy_wait
import time
import argparse
from agilex_robot import AgilexRobot
import torch
def get_arguments():
parser = argparse.ArgumentParser()
args = parser.parse_args()
args.fps = 30
args.resume = False
args.repo_id = "tangger/test"
args.root = "./data2"
args.num_image_writer_processes = 0
args.num_image_writer_threads_per_camera = 4
args.video = True
args.num_episodes = 50
args.episode_time_s = 30000
args.play_sounds = False
args.display_cameras = True
args.single_task = "test test"
args.use_depth_image = False
args.use_base = False
args.push_to_hub = False
args.policy = None
args.teleoprate = False
return args
cfg = get_arguments()
robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
inference_time_s = 360
fps = 30
device = "cuda" # TODO: On Mac, use "mps" or "cpu"
ckpt_path = "/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model"
policy = ACTPolicy.from_pretrained(ckpt_path)
policy.to(device)
for _ in range(inference_time_s * fps):
start_time = time.perf_counter()
# Read the follower state and access the frames from the cameras
observation = robot.capture_observation()
if observation is None:
print("Observation is None, skipping...")
continue
# Convert to pytorch format: channel first and float32 in [0,1]
# with batch dimension
for name in observation:
if "image" in name:
observation[name] = observation[name].type(torch.float32) / 255
observation[name] = observation[name].permute(2, 0, 1).contiguous()
observation[name] = observation[name].unsqueeze(0)
observation[name] = observation[name].to(device)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
# Remove batch dimension
action = action.squeeze(0)
# Move to cpu, if not already the case
action = action.to("cpu")
# Order the robot to move
robot.send_action(action)
dt_s = time.perf_counter() - start_time
busy_wait(1 / fps - dt_s)