diff --git a/.gitignore b/.gitignore index 8e255c0..a587ac0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ cobot_magic/ -librealsense/ \ No newline at end of file +librealsense/ +data*/ \ No newline at end of file diff --git a/collect_data/__pycache__/agilex_robot.cpython-310.pyc b/collect_data/__pycache__/agilex_robot.cpython-310.pyc index 6eb2734..b8b4a33 100644 Binary files a/collect_data/__pycache__/agilex_robot.cpython-310.pyc and b/collect_data/__pycache__/agilex_robot.cpython-310.pyc differ diff --git a/collect_data/__pycache__/robot_components.cpython-310.pyc b/collect_data/__pycache__/robot_components.cpython-310.pyc new file mode 100644 index 0000000..2f08b67 Binary files /dev/null and b/collect_data/__pycache__/robot_components.cpython-310.pyc differ diff --git a/collect_data/__pycache__/rosrobot.cpython-310.pyc b/collect_data/__pycache__/rosrobot.cpython-310.pyc index 17b14c6..67ba2c5 100644 Binary files a/collect_data/__pycache__/rosrobot.cpython-310.pyc and b/collect_data/__pycache__/rosrobot.cpython-310.pyc differ diff --git a/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc b/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc new file mode 100644 index 0000000..fc77f3f Binary files /dev/null and b/collect_data/__pycache__/rosrobot_factory.cpython-310.pyc differ diff --git a/collect_data/agilex.yaml b/collect_data/agilex.yaml index 5a9f04a..703b7e2 100644 --- a/collect_data/agilex.yaml +++ b/collect_data/agilex.yaml @@ -19,6 +19,12 @@ cameras: rgb_shape: [480, 640, 3] width: 480 height: 640 + cam_high: + img_topic_name: /camera/color/image_raw + depth_topic_name: /camera/depth/image_rect_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 arm: master_left: diff --git a/collect_data/collect_data_lerobot.py b/collect_data/collect_data_lerobot.py index 5a6cd22..e880e54 100644 --- a/collect_data/collect_data_lerobot.py +++ b/collect_data/collect_data_lerobot.py @@ -408,7 +408,7 @@ def get_arguments(): args.fps = 30 args.resume = False args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right" - args.root = "/home/ubuntu/LYT/aloha_lerobot/data4" + args.root = "./data5" args.episode = 0 # replay episode args.num_image_writer_processes = 0 args.num_image_writer_threads_per_camera = 4 @@ -429,51 +429,26 @@ def get_arguments(): -# @parser.wrap() -# def control_robot(cfg: ControlPipelineConfig): -# init_logging() -# logging.info(pformat(asdict(cfg))) - -# # robot = make_robot_from_config(cfg.robot) -# from agilex_robot import AgilexRobot -# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) - -# if isinstance(cfg.control, RecordControlConfig): -# print(cfg.control) -# record(robot, cfg.control) -# elif isinstance(cfg.control, ReplayControlConfig): -# replay(robot, cfg.control) - -# # if robot.is_connected: -# # # Disconnect manually to avoid a "Core dump" during process -# # # termination due to camera threads not properly exiting. 
-# # robot.disconnect() - # @parser.wrap() def control_robot(cfg): - # robot = make_robot_from_config(cfg.robot) - from agilex_robot import AgilexRobot - robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) + from rosrobot_factory import RobotFactory + # Create the robot instance via the factory pattern + robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/collect_data/agilex.yaml", args=cfg) if cfg.control_type == "record": record(robot, cfg) elif cfg.control_type == "replay": replay(robot, cfg) - # if robot.is_connected: - # # Disconnect manually to avoid a "Core dump" during process - # # termination due to camera threads not properly exiting. - # robot.disconnect() if __name__ == "__main__": cfg = get_arguments() control_robot(cfg) # control_robot() - # cfg = get_arguments() - # from agilex_robot import AgilexRobot - # robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) + # Create the robot instance via the factory pattern + # robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) # print(robot.features.items()) # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"]) # record(robot, cfg) diff --git a/collect_data/rosrobot_factory.py b/collect_data/rosrobot_factory.py deleted file mode 100644 index 6ca6bec..0000000 --- a/collect_data/rosrobot_factory.py +++ /dev/null @@ -1,26 +0,0 @@ -import yaml -import argparse -from typing import Dict, List, Any, Optional -from rosrobot import Robot -from agilex_robot import AgilexRobot - - -class RobotFactory: - @staticmethod - def create(config_file: str, args: Optional[argparse.Namespace] = None) -> Robot: - """ - 根据配置文件自动创建合适的机器人实例 - Args: - config_file: 配置文件路径 - args: 运行时参数 - """ - with open(config_file, 'r') as f: - config = yaml.safe_load(f) - - robot_type = config.get('robot_type', 'agilex') - - if robot_type == 'agilex': - return AgilexRobot(config_file, args) - # 可扩展其他机器人类型 - else: - raise ValueError(f"Unsupported robot type: {robot_type}") diff --git a/init_robot.bash b/init_robot.bash new file mode 100644 index 0000000..f77c8ca --- /dev/null +++ b/init_robot.bash @@ -0,0 +1,2 @@ +source ~/ros_noetic/devel_isolated/setup.bash +cd cobot_magic/remote_control-x86-can-v2 && ./tools/can.sh && ./tools/jgl_2follower.sh \ No newline at end of file diff --git a/lerobot b/lerobot new file mode 160000 index 0000000..1c873df --- /dev/null +++ b/lerobot @@ -0,0 +1 @@ +Subproject commit 1c873df5c0dd4dd9a81cbd90e07dd95a272ee3f7 diff --git a/lerobot_aloha/README.MD b/lerobot_aloha/README.MD new file mode 100644 index 0000000..9e4d14a --- /dev/null +++ b/lerobot_aloha/README.MD @@ -0,0 +1,3 @@ +python collect_data.py --robot.type=aloha --control.type=record --control.fps=30 --control.single_task="Grasp a lego block and put it in the bin."
--control.repo_id=tangger/test --control.num_episodes=1 --control.root=./data + +python lerobot/scripts/train.py --dataset.repo_id=maic/move_tube_on_scale --policy.type=act --output_dir=outputs/train/act_move_tube_on_scale --job_name=act_move_tube_on_scale --policy.device=cuda --wandb.enable=true --dataset.root=/home/ubuntu/LYT/aloha_lerobot/data1 \ No newline at end of file diff --git a/lerobot_aloha/collect_data_lerobot.py b/lerobot_aloha/collect_data_lerobot.py new file mode 100644 index 0000000..8ee0a52 --- /dev/null +++ b/lerobot_aloha/collect_data_lerobot.py @@ -0,0 +1,461 @@ +import logging +import time +from dataclasses import asdict +from pprint import pformat +from pprint import pprint + +# from safetensors.torch import load_file, save_file +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.robot_devices.control_configs import ( + CalibrateControlConfig, + ControlPipelineConfig, + RecordControlConfig, + RemoteRobotConfig, + ReplayControlConfig, + TeleoperateControlConfig, +) +from lerobot.common.robot_devices.control_utils import ( + # init_keyboard_listener, + record_episode, + stop_recording, + is_headless +) +from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config +from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect +from lerobot.common.utils.utils import has_method, init_logging, log_say +from lerobot.common.utils.utils import get_safe_torch_device +from contextlib import nullcontext +from copy import copy +import torch +import rospy +import cv2 +from lerobot.configs import parser +from common.agilex_robot import AgilexRobot +from common.rosrobot_factory import RobotFactory + + +######################################################################################## +# Control modes +######################################################################################## + + +def predict_action(observation, policy, device, use_amp): + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + +def control_loop( + robot, + control_time_s=None, + teleoperate=False, + display_cameras=False, + dataset: LeRobotDataset | None = None, + events=None, + policy = None, + fps: int | None = None, + single_task: str | None = None, +): + # TODO(rcadene): Add option to record logs + # if not robot.is_connected: + # robot.connect() + + if events is None: + events = {"exit_early": False} + + if control_time_s is None: + control_time_s = float("inf") + + if dataset is not None and single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + + if dataset is not None and fps is not None and dataset.fps != fps: + raise ValueError(f"The dataset fps should be equal 
to requested fps ({dataset.fps} != {fps}).") + timestamp = 0 + start_episode_t = time.perf_counter() + rate = rospy.Rate(fps) + print_flag = True + while timestamp < control_time_s and not rospy.is_shutdown(): + # print(timestamp < control_time_s) + # print(rospy.is_shutdown()) + start_loop_t = time.perf_counter() + + if teleoperate: + observation, action = robot.teleop_step() + if observation is None or action is None: + if print_flag: + print("failed to sync data, retrying...\n") + print_flag = False + rate.sleep() + continue + else: + # pass + observation = robot.capture_observation() + if policy is not None: + pred_action = predict_action( + observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + ) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + action = robot.send_action(pred_action) + action = {"action": action} + + if dataset is not None: + frame = {**observation, **action, "task": single_task} + dataset.add_frame(frame) + + # if display_cameras and not is_headless(): + # image_keys = [key for key in observation if "image" in key] + # for key in image_keys: + # if "depth" in key: + # pass + # else: + # cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + + # print(1) + # cv2.waitKey(1) + + if display_cameras and not is_headless(): + image_keys = [key for key in observation if "image" in key] + + # Get the screen resolution (assumed to be 1920x1080; adjust to the actual display) + screen_width = 1920 + screen_height = 1080 + + # Compute the window layout + num_images = len(image_keys) + max_columns = int(screen_width / 640) # assume each window is 640 px wide + rows = (num_images + max_columns - 1) // max_columns # number of rows needed + columns = min(num_images, max_columns) # number of columns actually used + + # Iterate over all image keys and display them + for idx, key in enumerate(image_keys): + if "depth" in key: + continue # skip depth images + + # Convert the image from RGB to BGR + image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + + # Create the window + cv2.imshow(key, image) + + # Compute the window position + window_width = 640 + window_height = 480 + row = idx // max_columns + col = idx % max_columns + x_position = col * window_width + y_position = row * window_height + + # Move the window to its slot + cv2.moveWindow(key, x_position, y_position) + + # Wait 1 ms so GUI events get processed + cv2.waitKey(1) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + # log_control_info(robot, dt_s, fps=fps) + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + break + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["record_start"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + events["record_start"] = False + elif key == keyboard.Key.left: + print("Left arrow key pressed.
Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + elif key == keyboard.Key.up: + print("Up arrow pressed. Start data recording...") + events["record_start"] = True + + + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def stop_recording(robot, listener, display_cameras): + + if not is_headless(): + if listener is not None: + listener.stop() + + if display_cameras: + cv2.destroyAllWindows() + + +def record_episode( + robot, + dataset, + events, + episode_time_s, + display_cameras, + policy, + fps, + single_task, +): + control_loop( + robot=robot, + control_time_s=episode_time_s, + display_cameras=display_cameras, + dataset=dataset, + events=events, + policy=policy, + fps=fps, + teleoperate=policy is None, + single_task=single_task, + ) + + +def record( + robot, + cfg +) -> LeRobotDataset: + # TODO(rcadene): Add option to record logs + if cfg.resume: + dataset = LeRobotDataset( + cfg.repo_id, + root=cfg.root, + ) + if len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.num_image_writer_processes, + num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) + else: + # Create empty dataset or load existing saved episodes + # sanity_check_dataset_name(cfg.repo_id, cfg.policy) + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.root, + robot=None, + features=robot.features, + use_videos=cfg.video, + image_writer_processes=cfg.num_image_writer_processes, + image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + # policy = None + + # if not robot.is_connected: + # robot.connect() + + listener, events = init_keyboard_listener() + + # Execute a few seconds without recording to: + # 1. teleoperate the robot to move it in starting position if no policy provided, + # 2. give times to the robot devices to connect and start synchronizing, + # 3. place the cameras windows on screen + enable_teleoperation = policy is None + log_say("Warmup record", cfg.play_sounds) + print() + print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n") + # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps) + + # if has_method(robot, "teleop_safety_stop"): + # robot.teleop_safety_stop() + + recorded_episodes = 0 + while True: + if recorded_episodes >= cfg.num_episodes: + break + + # if events["record_start"]: + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) + pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}") + record_episode( + robot=robot, + dataset=dataset, + events=events, + episode_time_s=cfg.episode_time_s, + display_cameras=cfg.display_cameras, + policy=policy, + fps=cfg.fps, + single_task=cfg.single_task, + ) + + # Execute a few seconds without recording to give time to manually reset the environment + # Current code logic doesn't allow to teleoperate during this time. 
+ # TODO(rcadene): add an option to enable teleoperation during reset + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + pprint("Reset the environment, recording paused") + # reset_environment(robot, events, cfg.reset_time_s, cfg.fps) + + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + pprint("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 + + if events["stop_recording"]: + break + + log_say("Stop recording", cfg.play_sounds, blocking=True) + stop_recording(robot, listener, cfg.display_cameras) + + if cfg.push_to_hub: + dataset.push_to_hub(tags=cfg.tags, private=cfg.private) + + log_say("Exiting", cfg.play_sounds) + return dataset + + +def replay( + robot: AgilexRobot, + cfg, +): + # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset + # TODO(rcadene): Add option to record logs + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) + actions = dataset.hf_dataset.select_columns("action") + + # if not robot.is_connected: + # robot.connect() + + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + robot.send_action(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / cfg.fps - dt_s) + + dt_s = time.perf_counter() - start_episode_t + # log_control_info(robot, dt_s, fps=cfg.fps) + + +import argparse +def get_arguments(): + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.fps = 30 + args.resume = False + args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right" + args.root = "./data5" + args.episode = 0 # replay episode + args.num_image_writer_processes = 0 + args.num_image_writer_threads_per_camera = 4 + args.video = True + args.num_episodes = 100 + args.episode_time_s = 30000 + args.play_sounds = False + args.display_cameras = True + args.single_task = "move the bottle from the right to the scale right" + args.use_depth_image = False + args.use_base = False + args.push_to_hub = False + args.policy = None + # args.teleoperate = True + args.control_type = "record" + # args.control_type = "replay" + return args + + + + +# @parser.wrap() +def control_robot(cfg): + # Create the robot instance via the factory pattern + robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg) + + if cfg.control_type == "record": + record(robot, cfg) + elif cfg.control_type == "replay": + replay(robot, cfg) + + +if __name__ == "__main__": + cfg = get_arguments() + control_robot(cfg) + # control_robot() + # Create the robot instance via the factory pattern + # robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) + # print(robot.features.items()) + # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"]) + # record(robot, cfg) + # capture = robot.capture_observation() + # import torch + # torch.save(capture, "test.pt") + # action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094, + # 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]], + # device='cpu') + # robot.send_action(action.squeeze(0)) + # print() \ No newline at end of file
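get_arguments() above parses an empty ArgumentParser and then hard-codes every setting on the returned namespace. A possible cleanup, sketched only (flag names are illustrative; defaults mirror the hard-coded values, and the remaining settings would be exposed the same way):

import argparse

def get_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--root", type=str, default="./data5")
    parser.add_argument("--repo_id", type=str, default="move_the_bottle_from_the_right_to_the_scale_right")
    parser.add_argument("--num_episodes", type=int, default=100)
    parser.add_argument("--episode_time_s", type=float, default=30000)
    parser.add_argument("--single_task", type=str, default="move the bottle from the right to the scale right")
    parser.add_argument("--control_type", type=str, choices=["record", "replay"], default="record")
    parser.add_argument("--resume", action="store_true")
    # ...declare video, play_sounds, display_cameras, push_to_hub, etc. the same way
    return parser.parse_args()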
diff --git a/lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc b/lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc new file mode 100644 index 0000000..84750e2 Binary files /dev/null and b/lerobot_aloha/common/__pycache__/agilex_robot.cpython-310.pyc differ diff --git a/lerobot_aloha/common/__pycache__/robot_components.cpython-310.pyc b/lerobot_aloha/common/__pycache__/robot_components.cpython-310.pyc new file mode 100644 index 0000000..00c1081 Binary files /dev/null and b/lerobot_aloha/common/__pycache__/robot_components.cpython-310.pyc differ diff --git a/lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc b/lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc new file mode 100644 index 0000000..9f5fa93 Binary files /dev/null and b/lerobot_aloha/common/__pycache__/rosrobot.cpython-310.pyc differ diff --git a/lerobot_aloha/common/__pycache__/rosrobot_factory.cpython-310.pyc b/lerobot_aloha/common/__pycache__/rosrobot_factory.cpython-310.pyc new file mode 100644 index 0000000..d058a8e Binary files /dev/null and b/lerobot_aloha/common/__pycache__/rosrobot_factory.cpython-310.pyc differ diff --git a/collect_data/agilex_robot.py b/lerobot_aloha/common/agilex_robot.py similarity index 94% rename from collect_data/agilex_robot.py rename to lerobot_aloha/common/agilex_robot.py index 28a701e..0dfca78 100644 --- a/collect_data/agilex_robot.py +++ b/lerobot_aloha/common/agilex_robot.py @@ -1,17 +1,13 @@ -import yaml import cv2 import numpy as np import collections import dm_env import argparse from typing import Dict, List, Any, Optional -from collections import deque import rospy -from cv_bridge import CvBridge from std_msgs.msg import Header -from sensor_msgs.msg import Image, JointState -from nav_msgs.msg import Odometry -from rosrobot import Robot +from sensor_msgs.msg import JointState +from .rosrobot import Robot import torch import time @@ -40,9 +36,12 @@ class AgilexRobot(Robot): # print("can not get data from puppet topic") # return None - if len(self.sync_arm_queues['puppet_left']) == 0 or len(self.sync_arm_queues['puppet_right']) == 0: - print("can not get data from puppet topic") - return None + # Check that the required arm data is available + required_arms = ['puppet_left', 'puppet_right'] + for arm_name in required_arms: + if arm_name not in self.sync_arm_queues or len(self.sync_arm_queues[arm_name]) == 0: + print(f"cannot get data from {arm_name} topic") + return None # Compute the minimum timestamp timestamps = [ @@ -330,12 +329,18 @@ class AgilexRobot(Robot): Returns: The actual action that was sent (may be clipped if safety checks are implemented) """ - # if not hasattr(self, 'puppet_arm_publishers'): - # # Initialize publishers on first call - # self._init_action_publishers() + # Default velocity and effort values + last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, + 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, + -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.03296661376953125] - last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] - last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625,
0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] + last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, + 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, + 0.0, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, + -0.03296661376953125, -0.03296661376953125] # Convert tensor to numpy array if needed if isinstance(action, torch.Tensor): diff --git a/collect_data/rosrobot.py b/lerobot_aloha/common/robot_components.py similarity index 50% rename from collect_data/rosrobot.py rename to lerobot_aloha/common/robot_components.py index 3a2de20..7669f71 100644 --- a/collect_data/rosrobot.py +++ b/lerobot_aloha/common/robot_components.py @@ -8,29 +8,38 @@ from nav_msgs.msg import Odometry import argparse -class Robot: - def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None): +class RobotConfig: + """Configuration management for robot components""" + + def __init__(self, config_file: str): """ - 机器人基类,处理通用初始化逻辑 + Initialize robot configuration from YAML file + Args: - config_file: YAML配置文件路径 - args: 运行时参数 + config_file: Path to YAML configuration file """ - self._load_config(config_file) - self._merge_runtime_args(args) - self._init_components() - self._init_data_structures() - self.init_ros() - self.init_features() - self.warmup() - - def _load_config(self, config_file: str) -> None: - """加载YAML配置文件""" + self.config = self._load_yaml(config_file) + self._validate_config() + + def _load_yaml(self, config_file: str) -> Dict[str, Any]: + """Load configuration from YAML file""" with open(config_file, 'r') as f: - self.config = yaml.safe_load(f) - - def _merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None: - """合并运行时参数到配置""" + return yaml.safe_load(f) + + def _validate_config(self) -> None: + """Validate configuration completeness""" + required_sections = ['cameras', 'arm'] + for section in required_sections: + if section not in self.config: + raise ValueError(f"Missing required config section: {section}") + + def merge_runtime_args(self, args: Optional[argparse.Namespace]) -> None: + """ + Merge runtime arguments into configuration + + Args: + args: Runtime arguments from command line + """ if args is None: return @@ -47,217 +56,56 @@ class Robot: for key, value in runtime_params.items(): if value is not None: self.config[key] = value + + def get(self, key: str, default=None) -> Any: + """Get configuration value with default fallback""" + return self.config.get(key, default) - def _init_components(self) -> None: - """初始化核心组件""" + +class RosAdapter: + """Adapter for ROS communication""" + + def __init__(self, config: RobotConfig): + """ + Initialize ROS adapter + + Args: + config: Robot configuration + """ + self.config = config self.bridge = CvBridge() self.subscribers = {} self.publishers = {} - self._validate_config() - - def _validate_config(self) -> None: - """验证配置完整性""" - required_sections = ['cameras', 'arm'] - for section in required_sections: - if section not in self.config: - raise ValueError(f"Missing required config section: {section}") - - def _init_data_structures(self) -> None: - """初始化数据结构模板方法""" - # 相机数据 - self.cameras = self.config.get('cameras', {}) - self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras} - # 深度数据 - self.use_depth_image = self.config.get('use_depth_image', False) - if self.use_depth_image: - 
self.sync_depth_queues = { - name: deque(maxlen=2000) - for name, cam in self.cameras.items() - if 'depth_topic_name' in cam - } + def init_ros_node(self, node_name: str = None) -> None: + """Initialize ROS node""" + if node_name is None: + node_name = self.config.get('ros_node_name', 'generic_robot_node') + + rospy.init_node(node_name, anonymous=True) - # 机械臂数据 - self.arms = self.config.get('arm', {}) - if self.config.get('control_type', '') != 'record': - # 如果不是录制模式,则仅初始化从机械臂数据队列 - self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name} - else: - self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms} - - # 机器人基座数据 - self.use_robot_base = self.config.get('use_robot_base', False) - if self.use_robot_base: - self.sync_base_queue = deque(maxlen=2000) - - def init_ros(self) -> None: - """初始化ROS订阅的模板方法""" - rospy.init_node( - f"{self.config.get('ros_node_name', 'generic_robot_node')}", - anonymous=True + def create_subscriber(self, topic: str, msg_type, callback, queue_size: int = 1000, tcp_nodelay: bool = True): + """Create a ROS subscriber""" + subscriber = rospy.Subscriber( + topic, + msg_type, + callback, + queue_size=queue_size, + tcp_nodelay=tcp_nodelay ) + return subscriber - self._setup_camera_subscribers() - self._setup_arm_subscribers_publishers() - self._setup_base_subscriber() - self._log_ros_status() - - def init_features(self): - """ - 根据YAML配置自动生成features结构 - """ - self.features = {} + def create_publisher(self, topic: str, msg_type, queue_size: int = 10): + """Create a ROS publisher""" + publisher = rospy.Publisher( + topic, + msg_type, + queue_size=queue_size + ) + return publisher - # 初始化相机特征 - self._init_camera_features() - - # 初始化机械臂特征 - self._init_state_features() - - self._init_action_features() - - # 初始化基座特征(如果启用) - if self.use_robot_base: - self._init_base_features() - import pprint - pprint.pprint(self.features, indent=4) - - - def _init_camera_features(self): - """处理所有相机特征""" - for cam_name, cam_config in self.cameras.items(): - # 普通图像 - self.features[f"observation.images.{cam_name}"] = { - "dtype": "video" if self.config.get("video", False) else "image", - "shape": cam_config.get("rgb_shape", [480, 640, 3]), - "names": ["height", "width", "channel"], - # "video_info": { - # "video.fps": cam_config.get("fps", 30.0), - # "video.codec": cam_config.get("codec", "av1"), - # "video.pix_fmt": cam_config.get("pix_fmt", "yuv420p"), - # "video.is_depth_map": False, - # "has_audio": False - # } - } - - if self.config.get("use_depth_image", False): - self.features[f"observation.images.depth_{cam_name}"] = { - "dtype": "uint16", - "shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1), - "names": ["height", "width"], - } - - - def _init_state_features(self): - state = self.config.get('state', {}) - # 状态特征 - self.features["observation.state"] = { - "dtype": "float32", - "shape": (len(state.get('motors', "")),), - "names": {"motors": state.get('motors', "")} - } - - if self.config.get('velocity'): - velocity = self.config.get('velocity', "") - self.features["observation.velocity"] = { - "dtype": "float32", - "shape": (len(velocity.get('motors', "")),), - "names": {"motors": velocity.get('motors', "")} - } - - if self.config.get('effort'): - effort = self.config.get('effort', "") - self.features["observation.effort"] = { - "dtype": "float32", - "shape": (len(effort.get('motors', "")),), - "names": {"motors": effort.get('motors', "")} - } - - - - def _init_action_features(self): - action = 
self.config.get('action', {}) - # 状态特征 - self.features["action"] = { - "dtype": "float32", - "shape": (len(action.get('motors', "")),), - "names": {"motors": action.get('motors', "")} - } - - def _init_base_features(self): - """处理基座特征""" - self.features["observation.base_vel"] = { - "dtype": "float32", - "shape": (2,), - "names": ["linear_x", "angular_z"] - } - - - def _setup_camera_subscribers(self) -> None: - """设置相机订阅者""" - for cam_name, cam_config in self.cameras.items(): - if 'img_topic_name' in cam_config: - self.subscribers[f"camera_{cam_name}"] = rospy.Subscriber( - cam_config['img_topic_name'], - Image, - self._make_camera_callback(cam_name, is_depth=False), - queue_size=1000, - tcp_nodelay=True - ) - - if self.use_depth_image and 'depth_topic_name' in cam_config: - self.subscribers[f"depth_{cam_name}"] = rospy.Subscriber( - cam_config['depth_topic_name'], - Image, - self._make_camera_callback(cam_name, is_depth=True), - queue_size=1000, - tcp_nodelay=True - ) - - def _setup_arm_subscribers_publishers(self) -> None: - """设置机械臂订阅者""" - # 当为record模式时,主从机械臂都需要订阅 - # 否则只订阅从机械臂,但向主机械臂发布 - if self.config.get('control_type', '') == 'record': - for arm_name, arm_config in self.arms.items(): - if 'topic_name' in arm_config: - self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber( - arm_config['topic_name'], - JointState, - self._make_arm_callback(arm_name), - queue_size=1000, - tcp_nodelay=True - ) - else: - for arm_name, arm_config in self.arms.items(): - if 'puppet' in arm_name: - self.subscribers[f"arm_{arm_name}"] = rospy.Subscriber( - arm_config['topic_name'], - JointState, - self._make_arm_callback(arm_name), - queue_size=1000, - tcp_nodelay=True - ) - if 'master' in arm_name: - self.publishers[f"arm_{arm_name}"] = rospy.Publisher( - arm_config['topic_name'], - JointState, - queue_size=10 - ) - - def _setup_base_subscriber(self) -> None: - """设置基座订阅者""" - if self.use_robot_base and 'robot_base' in self.config: - self.subscribers['base'] = rospy.Subscriber( - self.config['robot_base']['topic_name'], - Odometry, - self.robot_base_callback, - queue_size=1000, - tcp_nodelay=True - ) - - def _log_ros_status(self) -> None: - """记录ROS状态""" + def log_status(self) -> None: + """Log ROS connection status""" rospy.loginfo("\n=== ROS subscription status ===") rospy.loginfo(f"Initialized node: {rospy.get_name()}") rospy.loginfo("Active subscribers:") @@ -265,8 +113,74 @@ rospy.loginfo(f" - {topic}: {'active' if sub.impl else 'not connected'}") rospy.loginfo("=================") + +class RobotSensors: + """Management of robot sensors (cameras, depth sensors)""" + + def __init__(self, config: RobotConfig, ros_adapter: RosAdapter): + """ + Initialize robot sensors + + Args: + config: Robot configuration + ros_adapter: ROS communication adapter + """ + self.config = config + self.ros_adapter = ros_adapter + self.bridge = ros_adapter.bridge + + # Camera data + self.cameras = config.get('cameras', {}) + self.sync_img_queues = {name: deque(maxlen=2000) for name in self.cameras} + + # Depth data + self.use_depth_image = config.get('use_depth_image', False) + if self.use_depth_image: + self.sync_depth_queues = { + name: deque(maxlen=2000) + for name, cam in self.cameras.items() + if 'depth_topic_name' in cam + } + + # Robot base data + self.use_robot_base = config.get('use_robot_base', False) + if self.use_robot_base: + self.sync_base_queue = deque(maxlen=2000) + + def setup_subscribers(self) -> None: + """Set up ROS subscribers for sensors""" + self._setup_camera_subscribers() + if self.use_robot_base: + self._setup_base_subscriber() + + def
_setup_camera_subscribers(self) -> None: + """Set up camera subscribers""" + for cam_name, cam_config in self.cameras.items(): + if 'img_topic_name' in cam_config: + self.ros_adapter.subscribers[f"camera_{cam_name}"] = self.ros_adapter.create_subscriber( + cam_config['img_topic_name'], + Image, + self._make_camera_callback(cam_name, is_depth=False) + ) + + if self.use_depth_image and 'depth_topic_name' in cam_config: + self.ros_adapter.subscribers[f"depth_{cam_name}"] = self.ros_adapter.create_subscriber( + cam_config['depth_topic_name'], + Image, + self._make_camera_callback(cam_name, is_depth=True) + ) + + def _setup_base_subscriber(self) -> None: + """Set up base subscriber""" + if 'robot_base' in self.config.config: + self.ros_adapter.subscribers['base'] = self.ros_adapter.create_subscriber( + self.config.get('robot_base')['topic_name'], + Odometry, + self.robot_base_callback + ) + def _make_camera_callback(self, cam_name: str, is_depth: bool = False): - """生成相机回调函数工厂方法""" + """Generate camera callback factory method""" def callback(msg): try: target_queue = ( @@ -281,8 +195,105 @@ class Robot: rospy.logerr(f"Camera {cam_name} callback error: {str(e)}") return callback + def robot_base_callback(self, msg): + """Base callback default implementation""" + if len(self.sync_base_queue) >= 2000: + self.sync_base_queue.popleft() + self.sync_base_queue.append(msg) + + def init_features(self) -> Dict[str, Any]: + """Initialize sensor features""" + features = {} + + # Initialize camera features + self._init_camera_features(features) + + # Initialize base features (if enabled) + if self.use_robot_base: + self._init_base_features(features) + + return features + + def _init_camera_features(self, features: Dict[str, Any]) -> None: + """Process all camera features""" + for cam_name, cam_config in self.cameras.items(): + # Regular images + features[f"observation.images.{cam_name}"] = { + "dtype": "video" if self.config.get("video", False) else "image", + "shape": cam_config.get("rgb_shape", [480, 640, 3]), + "names": ["height", "width", "channel"], + } + + if self.config.get("use_depth_image", False): + features[f"observation.images.depth_{cam_name}"] = { + "dtype": "uint16", + "shape": (cam_config.get("width", 480), cam_config.get("height", 640), 1), + "names": ["height", "width"], + } + + def _init_base_features(self, features: Dict[str, Any]) -> None: + """Process base features""" + features["observation.base_vel"] = { + "dtype": "float32", + "shape": (2,), + "names": ["linear_x", "angular_z"] + } + + +class RobotActuators: + """Management of robot actuators (arms, base)""" + + def __init__(self, config: RobotConfig, ros_adapter: RosAdapter): + """ + Initialize robot actuators + + Args: + config: Robot configuration + ros_adapter: ROS communication adapter + """ + self.config = config + self.ros_adapter = ros_adapter + + # Arm data + self.arms = config.get('arm', {}) + if config.get('control_type', '') != 'record': + # If not in record mode, only initialize puppet arm queues + self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms if 'puppet' in name} + else: + self.sync_arm_queues = {name: deque(maxlen=2000) for name in self.arms} + + def setup_subscribers_publishers(self) -> None: + """Set up ROS subscribers and publishers for actuators""" + self._setup_arm_subscribers_publishers() + + def _setup_arm_subscribers_publishers(self) -> None: + """Set up arm subscribers and publishers""" + # When in record mode, subscribe to both master and puppet arms + # Otherwise only subscribe to 
puppet arms, but publish to master arms + if self.config.get('control_type', '') == 'record': + for arm_name, arm_config in self.arms.items(): + if 'topic_name' in arm_config: + self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber( + arm_config['topic_name'], + JointState, + self._make_arm_callback(arm_name) + ) + else: + for arm_name, arm_config in self.arms.items(): + if 'puppet' in arm_name: + self.ros_adapter.subscribers[f"arm_{arm_name}"] = self.ros_adapter.create_subscriber( + arm_config['topic_name'], + JointState, + self._make_arm_callback(arm_name) + ) + if 'master' in arm_name: + self.ros_adapter.publishers[f"arm_{arm_name}"] = self.ros_adapter.create_publisher( + arm_config['topic_name'], + JointState + ) + def _make_arm_callback(self, arm_name: str): - """生成机械臂回调函数工厂方法""" + """Generate arm callback factory method""" def callback(msg): try: if len(self.sync_arm_queues[arm_name]) >= 2000: @@ -292,17 +303,74 @@ class Robot: rospy.logerr(f"Arm {arm_name} callback error: {str(e)}") return callback - def robot_base_callback(self, msg): - """基座回调默认实现""" - if len(self.sync_base_queue) >= 2000: - self.sync_base_queue.popleft() - self.sync_base_queue.append(msg) + def init_features(self) -> Dict[str, Any]: + """Initialize actuator features""" + features = {} + + # Initialize arm features + self._init_state_features(features) + self._init_action_features(features) + + return features + + def _init_state_features(self, features: Dict[str, Any]) -> None: + """Initialize state features""" + state = self.config.get('state', {}) + # State features + features["observation.state"] = { + "dtype": "float32", + "shape": (len(state.get('motors', "")),), + "names": {"motors": state.get('motors', "")} + } - def warmup(self, timeout: float = 10.0) -> bool: - """Wait until all data queues have at least 20 messages. 
+ if self.config.get('velocity'): + velocity = self.config.get('velocity', "") + features["observation.velocity"] = { + "dtype": "float32", + "shape": (len(velocity.get('motors', "")),), + "names": {"motors": velocity.get('motors', "")} + } + + if self.config.get('effort'): + effort = self.config.get('effort', "") + features["observation.effort"] = { + "dtype": "float32", + "shape": (len(effort.get('motors', "")),), + "names": {"motors": effort.get('motors', "")} + } + + def _init_action_features(self, features: Dict[str, Any]) -> None: + """Initialize action features""" + action = self.config.get('action', {}) + features["action"] = { + "dtype": "float32", + "shape": (len(action.get('motors', "")),), + "names": {"motors": action.get('motors', "")} + } + + +class RobotDataManager: + """Management of robot data collection and synchronization""" + + def __init__(self, config: RobotConfig, sensors: RobotSensors, actuators: RobotActuators): + """ + Initialize robot data manager Args: - timeout: Maximum time to wait in seconds before giving up + config: Robot configuration + sensors: Robot sensors component + actuators: Robot actuators component + """ + self.config = config + self.sensors = sensors + self.actuators = actuators + + def warmup(self, timeout: float = 10.0) -> bool: + """ + Wait until all data queues have sufficient messages + + Args: + timeout: Maximum time to wait in seconds Returns: bool: True if warmup succeeded, False if timed out @@ -323,31 +391,24 @@ all_ready = True # Check camera image queues - for cam_name in self.cameras: - if len(self.sync_img_queues[cam_name]) < 50: - rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sync_img_queues[cam_name])}/50)") + for cam_name in self.sensors.cameras: + if len(self.sensors.sync_img_queues[cam_name]) < 50: + rospy.loginfo(f"Waiting for camera {cam_name} (current: {len(self.sensors.sync_img_queues[cam_name])}/50)") all_ready = False break # Check depth queues if enabled - if self.use_depth_image: - for cam_name in self.sync_depth_queues: - if len(self.sync_depth_queues[cam_name]) < 50: - rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sync_depth_queues[cam_name])}/50)") + if self.sensors.use_depth_image: + for cam_name in self.sensors.sync_depth_queues: + if len(self.sensors.sync_depth_queues[cam_name]) < 50: + rospy.loginfo(f"Waiting for depth camera {cam_name} (current: {len(self.sensors.sync_depth_queues[cam_name])}/50)") + all_ready = False + break - # # Check arm queues - # for arm_name in self.arms: - # if len(self.sync_arm_queues[arm_name]) < 20: - # rospy.loginfo(f"Waiting for arm {arm_name} (current: {len(self.sync_arm_queues[arm_name])}/20)") - # all_ready = False - # break - # Check base queue if enabled - if self.use_robot_base: - if len(self.sync_base_queue) < 20: - rospy.loginfo(f"Waiting for base (current: {len(self.sync_base_queue)}/20)") + if self.sensors.use_robot_base: + if len(self.sensors.sync_base_queue) < 20: + rospy.loginfo(f"Waiting for base (current: {len(self.sensors.sync_base_queue)}/20)") + all_ready = False + + # If all queues are ready, return success @@ -357,16 +418,4 @@ rate.sleep() - return False - - - - - - def get_frame(self) -> Optional[Dict[str, Any]]: - """获取同步帧数据的模板方法""" - raise NotImplementedError("Subclasses must implement get_frame()") - - def process(self) -> tuple: - """主处理循环的模板方法""" - raise NotImplementedError("Subclasses must implement process()") + return False \ No newline at end of file
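A minimal sketch of driving the refactored components directly, without the Robot facade defined in the next file (the config path and node name are illustrative; per the code above, warmup() returns True once every camera queue holds at least 50 frames and, if enabled, the base queue holds at least 20 messages):

from common.robot_components import (
    RobotConfig, RosAdapter, RobotSensors, RobotActuators, RobotDataManager,
)

config = RobotConfig("configs/agilex.yaml")  # raises ValueError unless 'cameras' and 'arm' sections exist
adapter = RosAdapter(config)
adapter.init_ros_node("component_smoke_test")
sensors = RobotSensors(config, adapter)
actuators = RobotActuators(config, adapter)
sensors.setup_subscribers()
actuators.setup_subscribers_publishers()
manager = RobotDataManager(config, sensors, actuators)
print("queues ready:", manager.warmup(timeout=10.0))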
diff --git a/lerobot_aloha/common/rosrobot.py b/lerobot_aloha/common/rosrobot.py new file mode 100644 index 0000000..30f80dd --- /dev/null +++ b/lerobot_aloha/common/rosrobot.py @@ -0,0 +1,136 @@ +import yaml +from typing import Dict, Any, Optional, List +import argparse +from .robot_components import RobotConfig, RosAdapter, RobotSensors, RobotActuators, RobotDataManager + + +class Robot: + def __init__(self, config_file: str, args: Optional[argparse.Namespace] = None): + """ + Robot base class that handles the common initialization logic + Args: + config_file: path to the YAML configuration file + args: runtime arguments + """ + # Initialize the components + self.config = RobotConfig(config_file) + self.config.merge_runtime_args(args) + self.ros_adapter = RosAdapter(self.config) + self.sensors = RobotSensors(self.config, self.ros_adapter) + self.actuators = RobotActuators(self.config, self.ros_adapter) + self.data_manager = RobotDataManager(self.config, self.sensors, self.actuators) + + # Initialize ROS and the features + self.init_ros() + self.init_features() + self.warmup() + + def get(self, key: str, default=None) -> Any: + """Get a configuration value""" + return self.config.get(key, default) + + @property + def bridge(self): + """Get the CV bridge""" + return self.ros_adapter.bridge + + @property + def subscribers(self): + """Get the subscribers""" + return self.ros_adapter.subscribers + + @property + def publishers(self): + """Get the publishers""" + return self.ros_adapter.publishers + + @property + def cameras(self): + """Get the camera configuration""" + return self.sensors.cameras + + @property + def arms(self): + """Get the arm configuration""" + return self.actuators.arms + + @property + def sync_img_queues(self): + """Get the image queues""" + return self.sensors.sync_img_queues + + @property + def sync_depth_queues(self): + """Get the depth image queues""" + return self.sensors.sync_depth_queues if hasattr(self.sensors, 'sync_depth_queues') else {} + + @property + def sync_arm_queues(self): + """Get the arm queues""" + return self.actuators.sync_arm_queues + + @property + def sync_base_queue(self): + """Get the base queue""" + return self.sensors.sync_base_queue if hasattr(self.sensors, 'sync_base_queue') else None + + @property + def use_depth_image(self): + """Whether depth images are used""" + return self.sensors.use_depth_image + + @property + def use_robot_base(self): + """Whether the robot base is used""" + return self.sensors.use_robot_base + + def init_ros(self) -> None: + """Template method that initializes the ROS subscriptions""" + self.ros_adapter.init_ros_node() + + # Set up the subscribers and publishers for sensors and actuators + self.sensors.setup_subscribers() + self.actuators.setup_subscribers_publishers() + + # Log the ROS status + self.ros_adapter.log_status() + + def init_features(self): + """ + Auto-generate the features structure from the YAML configuration + """ + # Merge the sensor and actuator features + self.features = {} + self.features.update(self.sensors.init_features()) + self.features.update(self.actuators.init_features()) + + import pprint + pprint.pprint(self.features, indent=4) + + + + + + + def warmup(self, timeout: float = 10.0) -> bool: + """Wait until all data queues hold enough messages.
+ + Args: + timeout: Maximum time to wait in seconds before giving up + + Returns: + bool: True if warmup succeeded, False if timed out + """ + return self.data_manager.warmup(timeout) + + + + + + def get_frame(self) -> Optional[Dict[str, Any]]: + """Template method that returns a synchronized data frame""" + raise NotImplementedError("Subclasses must implement get_frame()") + + def process(self) -> tuple: + """Template method for the main processing loop""" + raise NotImplementedError("Subclasses must implement process()") diff --git a/lerobot_aloha/common/rosrobot_factory.py b/lerobot_aloha/common/rosrobot_factory.py new file mode 100644 index 0000000..9409586 --- /dev/null +++ b/lerobot_aloha/common/rosrobot_factory.py @@ -0,0 +1,59 @@ +import yaml +import argparse +from typing import Dict, List, Any, Optional, Type +from .rosrobot import Robot +from .agilex_robot import AgilexRobot + + +class RobotFactory: + """Factory for creating robot instances based on configuration""" + + # Registry that stores the available robot types + _registry = {} + + @classmethod + def register(cls, robot_type: str, robot_class: Type[Robot]) -> None: + """ + Register a new robot type + + Args: + robot_type: robot type identifier + robot_class: robot class implementation + """ + cls._registry[robot_type] = robot_class + + @classmethod + def create(cls, config_file: str, args: Optional[argparse.Namespace] = None) -> Robot: + """ + Automatically create the appropriate robot instance from a configuration file + + Args: + config_file: path to the configuration file + args: runtime arguments + + Returns: + Robot: the created robot instance + + Raises: + ValueError: if the requested robot type is not supported + """ + with open(config_file, 'r') as f: + config = yaml.safe_load(f) + + robot_type = config.get('robot_type', 'agilex') + + # If the registry is empty, register the default robot types + if not cls._registry: + cls.register('agilex', AgilexRobot) + cls.register('aloha_agilex', AgilexRobot) # alias support + + # Look up the robot class in the registry + if robot_type in cls._registry: + return cls._registry[robot_type](config_file, args) + else: + raise ValueError(f"Unsupported robot type: {robot_type}. Available types: {list(cls._registry.keys())}") + + +# Register the available robot types +RobotFactory.register('agilex', AgilexRobot) +RobotFactory.register('aloha_agilex', AgilexRobot) # alias support
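As a sketch of the extension point this registry provides (MyRobot and its config file are hypothetical, and a running ROS master is assumed since Robot.__init__ initializes a node), a subclass only has to implement the two template methods of rosrobot.Robot, register itself, and can then be selected through the YAML robot_type key:

from common.rosrobot import Robot
from common.rosrobot_factory import RobotFactory

class MyRobot(Robot):
    def get_frame(self):
        # return one synchronized frame, or None if the queues are not ready
        return None

    def process(self):
        # main processing loop
        return ()

RobotFactory.register("my_robot", MyRobot)
# a config containing `robot_type: my_robot` now resolves to MyRobot
robot = RobotFactory.create(config_file="configs/my_robot.yaml", args=None)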
Available types: {list(cls._registry.keys())}") + + +# 注册可用的机器人类型 +RobotFactory.register('agilex', AgilexRobot) +RobotFactory.register('aloha_agilex', AgilexRobot) # 别名支持 diff --git a/lerobot_aloha/configs/agilex.yaml b/lerobot_aloha/configs/agilex.yaml new file mode 100644 index 0000000..703b7e2 --- /dev/null +++ b/lerobot_aloha/configs/agilex.yaml @@ -0,0 +1,146 @@ +robot_type: aloha_agilex +ros_node_name: record_episodes +cameras: + cam_front: + img_topic_name: /camera_f/color/image_raw + depth_topic_name: /camera_f/depth/image_raw + width: 480 + height: 640 + rgb_shape: [480, 640, 3] + cam_left: + img_topic_name: /camera_l/color/image_raw + depth_topic_name: /camera_l/depth/image_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + cam_right: + img_topic_name: /camera_r/color/image_raw + depth_topic_name: /camera_r/depth/image_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + cam_high: + img_topic_name: /camera/color/image_raw + depth_topic_name: /camera/depth/image_rect_raw + rgb_shape: [480, 640, 3] + width: 480 + height: 640 + +arm: + master_left: + topic_name: /master/joint_left + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none" + ] + master_right: + topic_name: /master/joint_right + motors: [ + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + puppet_left: + topic_name: /puppet/joint_left + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none" + ] + puppet_right: + topic_name: /puppet/joint_right + motors: [ + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +# follow the joint name in ros +state: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +velocity: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +effort: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] + +action: + motors: [ + "left_joint0", + "left_joint1", + "left_joint2", + "left_joint3", + "left_joint4", + "left_joint5", + "left_none", + "right_joint0", + "right_joint1", + "right_joint2", + "right_joint3", + "right_joint4", + "right_joint5", + "right_none" + ] diff --git a/lerobot_aloha/inference.py b/lerobot_aloha/inference.py new file mode 100644 index 0000000..34f7f52 --- /dev/null +++ b/lerobot_aloha/inference.py @@ -0,0 +1,769 @@ +#!/home/lin/software/miniconda3/envs/aloha/bin/python +# -- coding: UTF-8 +""" +#!/usr/bin/python3 +""" + +import torch +import numpy as np +import os +import pickle +import argparse +from einops import rearrange +import collections +from collections import deque + +import rospy +from std_msgs.msg import Header +from geometry_msgs.msg import Twist +from sensor_msgs.msg import JointState, Image +from nav_msgs.msg import Odometry +from cv_bridge import CvBridge 
diff --git a/lerobot_aloha/inference.py b/lerobot_aloha/inference.py new file mode 100644 index 0000000..34f7f52 --- /dev/null +++ b/lerobot_aloha/inference.py @@ -0,0 +1,769 @@ +#!/home/lin/software/miniconda3/envs/aloha/bin/python +# -- coding: UTF-8 +""" +#!/usr/bin/python3 +""" + +import torch +import numpy as np +import os +import pickle +import argparse +from einops import rearrange +import collections +from collections import deque + +import rospy +from std_msgs.msg import Header +from geometry_msgs.msg import Twist +from sensor_msgs.msg import JointState, Image +from nav_msgs.msg import Odometry +from cv_bridge import CvBridge +import time +import threading +import math + + + + +import sys +sys.path.append("./") + +SEED = 42 +torch.manual_seed(SEED) +np.random.seed(SEED) + +task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']} + +inference_thread = None +inference_lock = threading.Lock() +inference_actions = None +inference_timestep = None + + +def actions_interpolation(args, pre_action, actions, stats): + steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0) + pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + post_process = lambda a: a * stats['action_std'] + stats['action_mean'] + result = [pre_action] + post_action = post_process(actions[0]) + # print("pre_action:", pre_action[7:]) + # print("actions_interpolation1:", post_action[:, 7:]) + max_diff_index = 0 + max_diff = -1 + for i in range(post_action.shape[0]): + diff = 0 + for j in range(pre_action.shape[0]): + if j == 6 or j == 13: + continue + diff += math.fabs(pre_action[j] - post_action[i][j]) + if diff > max_diff: + max_diff = diff + max_diff_index = i + + for i in range(max_diff_index, post_action.shape[0]): + step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])]) + inter = np.linspace(result[-1], post_action[i], step+2) + result.extend(inter[1:]) + while len(result) < args.chunk_size+1: + result.append(result[-1]) + result = np.array(result)[1:args.chunk_size+1] + # print("actions_interpolation2:", result.shape, result[:, 7:]) + result = pre_process(result) + result = result[np.newaxis, :] + return result + + +def get_model_config(args): + # Fixing the random seed guarantees that the same sequence of random numbers is generated on every run, given the same initial conditions. + set_seed(1) + + # ACT policy + # fixed parameters + if args.policy_class == 'ACT': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': args.chunk_size, # action queries + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base, + 'kl_weight': args.kl_weight, # KL divergence weight + 'hidden_dim': args.hidden_dim, # hidden layer dimension + 'dim_feedforward': args.dim_feedforward, + 'enc_layers': args.enc_layers, + 'dec_layers': args.dec_layers, + 'nheads': args.nheads, + 'dropout': args.dropout, + 'pre_norm': args.pre_norm + } + elif args.policy_class == 'CNNMLP': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': 1, # action queries + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base + } + + elif args.policy_class == 'Diffusion': + policy_config = {'lr': args.lr, + 'lr_backbone': args.lr_backbone, + 'backbone': args.backbone, + 'masks': args.masks, + 'weight_decay': args.weight_decay, + 'dilation': args.dilation, + 'position_embedding': args.position_embedding, + 'loss_function': args.loss_function, + 'chunk_size': args.chunk_size, # action queries + 'camera_names': task_config['camera_names'], + 'use_depth_image': args.use_depth_image, + 'use_robot_base': args.use_robot_base, + 'observation_horizon': args.observation_horizon, +
'action_horizon': args.action_horizon, + 'num_inference_timesteps': args.num_inference_timesteps, + 'ema_power': args.ema_power + } + else: + raise NotImplementedError + + config = { + 'ckpt_dir': args.ckpt_dir, + 'ckpt_name': args.ckpt_name, + 'ckpt_stats_name': args.ckpt_stats_name, + 'episode_len': args.max_publish_step, + 'state_dim': args.state_dim, + 'policy_class': args.policy_class, + 'policy_config': policy_config, + 'temporal_agg': args.temporal_agg, + 'camera_names': task_config['camera_names'], + } + return config + + +def make_policy(policy_class, policy_config): + if policy_class == 'ACT': + policy = ACTPolicy(policy_config) + elif policy_class == 'CNNMLP': + policy = CNNMLPPolicy(policy_config) + elif policy_class == 'Diffusion': + policy = DiffusionPolicy(policy_config) + else: + raise NotImplementedError + return policy + + +def get_image(observation, camera_names): + curr_images = [] + for cam_name in camera_names: + curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w') + + curr_images.append(curr_image) + curr_image = np.stack(curr_images, axis=0) + curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) + return curr_image + + +def get_depth_image(observation, camera_names): + curr_images = [] + for cam_name in camera_names: + curr_images.append(observation['images_depth'][cam_name]) + curr_image = np.stack(curr_images, axis=0) + curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) + return curr_image + + +def inference_process(args, config, ros_operator, policy, stats, t, pre_action): + global inference_lock + global inference_actions + global inference_timestep + print_flag = True + pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] + pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"] + rate = rospy.Rate(args.publish_rate) + while not rospy.is_shutdown(): + result = ros_operator.get_frame() + if not result: + if print_flag: + print("sync fail") + print_flag = False + rate.sleep() + continue + print_flag = True + (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, robot_base) = result + obs = collections.OrderedDict() + image_dict = dict() + + image_dict[config['camera_names'][0]] = img_front + image_dict[config['camera_names'][1]] = img_left + image_dict[config['camera_names'][2]] = img_right + + + obs['images'] = image_dict + + if args.use_depth_image: + image_depth_dict = dict() + image_depth_dict[config['camera_names'][0]] = img_front_depth + image_depth_dict[config['camera_names'][1]] = img_left_depth + image_depth_dict[config['camera_names'][2]] = img_right_depth + obs['images_depth'] = image_depth_dict + + obs['qpos'] = np.concatenate( + (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0) + obs['qvel'] = np.concatenate( + (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0) + obs['effort'] = np.concatenate( + (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0) + if args.use_robot_base: + obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z] + obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0) + else: + obs['base_vel'] = [0.0, 0.0] + # qpos_numpy = np.array(obs['qpos']) + + # Normalize qpos and move it to CUDA + qpos = pre_pos_process(obs['qpos']) + qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) + # Fetch the current images into curr_image
+        # Fetch the current RGB frames (and depth frames if enabled).
+        curr_image = get_image(obs, config['camera_names'])
+        curr_depth_image = None
+        if args.use_depth_image:
+            curr_depth_image = get_depth_image(obs, config['camera_names'])
+        start_time = time.time()
+        all_actions = policy(curr_image, curr_depth_image, qpos)
+        end_time = time.time()
+        print("model cost time: ", end_time - start_time)
+        inference_lock.acquire()
+        inference_actions = all_actions.cpu().detach().numpy()
+        if pre_action is None:
+            pre_action = obs['qpos']
+        # print("obs['qpos']:", obs['qpos'][7:])
+        if args.use_actions_interpolation:
+            inference_actions = actions_interpolation(args, pre_action, inference_actions, stats)
+        inference_timestep = t
+        inference_lock.release()
+        break
+
+
+def model_inference(args, config, ros_operator, save_episode=True):
+    global inference_lock
+    global inference_actions
+    global inference_timestep
+    global inference_thread
+    set_seed(1000)
+
+    # 1. Build the policy (an nn.Module subclass).
+    policy = make_policy(config['policy_class'], config['policy_config'])
+    # print("model structure\n", policy.model)
+
+    # 2. Load the model weights, dropping heads that are unused at inference time.
+    ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name'])
+    state_dict = torch.load(ckpt_path)
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]:
+            continue
+        if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]:
+            continue
+        new_state_dict[key] = value
+    loading_status = policy.deserialize(new_state_dict)
+    if not loading_status:
+        print("ckpt path does not exist")
+        return False
+
+    # 3. Move the model to CUDA and switch to eval mode.
+    policy.cuda()
+    policy.eval()
+
+    # 4. Load the dataset statistics (action_mean, action_std, qpos_mean, qpos_std; 14-dim).
+    stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name'])
+    with open(stats_path, 'rb') as f:
+        stats = pickle.load(f)
+
+    # Pre- and post-processing functions for states and actions.
+    pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
+    post_process = lambda a: a * stats['action_std'] + stats['action_mean']
+
+    max_publish_step = config['episode_len']
+    chunk_size = config['policy_config']['chunk_size']
+
+    # Starting poses published before inference begins (the second pair mainly
+    # changes the gripper joint).
+    left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875]
+    right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875]
+    left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258]
+    right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883]
+
+    ros_operator.puppet_arm_publish_continuous(left0, right0)
+    input("Press Enter to continue: ")
+    ros_operator.puppet_arm_publish_continuous(left1, right1)
+    action = None
+    # Inference loop.
+    with torch.inference_mode():
+        while not rospy.is_shutdown():
+            # Step counter within the episode.
+            t = 0
+            max_t = 0
+            rate = rospy.Rate(args.publish_rate)
+            if config['temporal_agg']:
+                all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']])
+            while t < max_publish_step and not rospy.is_shutdown():
+                # start_time = time.time()
+                # query policy
+                if config['policy_class'] == "ACT":
+                    if t >= max_t:
+                        pre_action = action
+                        inference_thread = threading.Thread(target=inference_process,
+                                                            args=(args, config, ros_operator,
+                                                                  policy, stats, t, pre_action))
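+                        # Note: start() is immediately followed by join(), so the
+                        # "thread" effectively runs synchronously here; it mainly
+                        # reuses inference_process and its frame-sync retry loop.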
+                        inference_thread.start()
+                        inference_thread.join()
+                        inference_lock.acquire()
+                        if inference_actions is not None:
+                            inference_thread = None
+                            all_actions = inference_actions
+                            inference_actions = None
+                            max_t = t + args.pos_lookahead_step
+                            if config['temporal_agg']:
+                                all_time_actions[[t], t:t + chunk_size] = all_actions
+                        inference_lock.release()
+                    if config['temporal_agg']:
+                        # Temporal ensembling: average every still-valid prediction for
+                        # step t with exponentially decaying weights (older chunks first).
+                        actions_for_curr_step = all_time_actions[:, t]
+                        actions_populated = np.all(actions_for_curr_step != 0, axis=1)
+                        actions_for_curr_step = actions_for_curr_step[actions_populated]
+                        k = 0.01
+                        exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step)))
+                        exp_weights = exp_weights / exp_weights.sum()
+                        exp_weights = exp_weights[:, np.newaxis]
+                        raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True)
+                    else:
+                        if args.pos_lookahead_step != 0:
+                            raw_action = all_actions[:, t % args.pos_lookahead_step]
+                        else:
+                            raw_action = all_actions[:, t % chunk_size]
+                else:
+                    raise NotImplementedError
+                action = post_process(raw_action[0])
+                left_action = action[:7]  # first 7 dims: left arm
+                right_action = action[7:14]  # next 7 dims: right arm
+                ros_operator.puppet_arm_publish(left_action, right_action)  # or puppet_arm_publish_continuous_thread
+                if args.use_robot_base:
+                    vel_action = action[14:16]
+                    ros_operator.robot_base_publish(vel_action)
+                t += 1
+                # end_time = time.time()
+                # print("publish: ", t)
+                # print("time:", end_time - start_time)
+                # print("left_action:", left_action)
+                # print("right_action:", right_action)
+                rate.sleep()
+
+
+class RosOperator:
+    def __init__(self, args):
+        self.robot_base_deque = None
+        self.puppet_arm_right_deque = None
+        self.puppet_arm_left_deque = None
+        self.img_front_deque = None
+        self.img_right_deque = None
+        self.img_left_deque = None
+        self.img_front_depth_deque = None
+        self.img_right_depth_deque = None
+        self.img_left_depth_deque = None
+        self.bridge = None
+        self.puppet_arm_left_publisher = None
+        self.puppet_arm_right_publisher = None
+        self.robot_base_publisher = None
+        self.puppet_arm_publish_thread = None
+        self.puppet_arm_publish_lock = None
+        self.args = args
+        self.ctrl_state = False
+        self.ctrl_state_lock = threading.Lock()
+        self.init()
+        self.init_ros()
+
+    def init(self):
+        self.bridge = CvBridge()
+        self.img_left_deque = deque()
+        self.img_right_deque = deque()
+        self.img_front_deque = deque()
+        self.img_left_depth_deque = deque()
+        self.img_right_depth_deque = deque()
+        self.img_front_depth_deque = deque()
+        self.puppet_arm_left_deque = deque()
+        self.puppet_arm_right_deque = deque()
+        self.robot_base_deque = deque()
+        self.puppet_arm_publish_lock = threading.Lock()
+        self.puppet_arm_publish_lock.acquire()
+
+    def puppet_arm_publish(self, left, right):
+        joint_state_msg = JointState()
+        joint_state_msg.header = Header()
+        joint_state_msg.header.stamp = rospy.Time.now()  # timestamp
+        joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
+        joint_state_msg.position = left
+        self.puppet_arm_left_publisher.publish(joint_state_msg)
+        joint_state_msg.position = right
+        self.puppet_arm_right_publisher.publish(joint_state_msg)
+
+    def robot_base_publish(self, vel):
+        vel_msg = Twist()
+        vel_msg.linear.x = vel[0]
+        vel_msg.linear.y = 0
+        vel_msg.linear.z = 0
+        vel_msg.angular.x = 0
+        vel_msg.angular.y = 0
+        vel_msg.angular.z = vel[1]
+        self.robot_base_publisher.publish(vel_msg)
+
+    def puppet_arm_publish_continuous(self, left, right):
+        rate = rospy.Rate(self.args.publish_rate)
+        left_arm = None
+        right_arm = None
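+        # Ramp toward the target pose: wait for the first joint states, then on
+        # every tick move each joint by at most args.arm_steps_length[i] toward
+        # the target until all joints have converged. A successful non-blocking
+        # acquire of puppet_arm_publish_lock aborts the ramp early.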
+        while not rospy.is_shutdown():
+            if len(self.puppet_arm_left_deque) != 0:
+                left_arm = list(self.puppet_arm_left_deque[-1].position)
+            if len(self.puppet_arm_right_deque) != 0:
+                right_arm = list(self.puppet_arm_right_deque[-1].position)
+            if left_arm is None or right_arm is None:
+                rate.sleep()
+                continue
+            else:
+                break
+        left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))]
+        right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))]
+        flag = True
+        step = 0
+        while flag and not rospy.is_shutdown():
+            if self.puppet_arm_publish_lock.acquire(False):
+                return
+            left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))]
+            right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))]
+            flag = False
+            for i in range(len(left)):
+                if left_diff[i] < self.args.arm_steps_length[i]:
+                    left_arm[i] = left[i]
+                else:
+                    left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i]
+                    flag = True
+            for i in range(len(right)):
+                if right_diff[i] < self.args.arm_steps_length[i]:
+                    right_arm[i] = right[i]
+                else:
+                    right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i]
+                    flag = True
+            joint_state_msg = JointState()
+            joint_state_msg.header = Header()
+            joint_state_msg.header.stamp = rospy.Time.now()  # timestamp
+            joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
+            joint_state_msg.position = left_arm
+            self.puppet_arm_left_publisher.publish(joint_state_msg)
+            joint_state_msg.position = right_arm
+            self.puppet_arm_right_publisher.publish(joint_state_msg)
+            step += 1
+            print("puppet_arm_publish_continuous:", step)
+            rate.sleep()
+
+    def puppet_arm_publish_linear(self, left, right):
+        num_step = 100
+        rate = rospy.Rate(200)
+
+        left_arm = None
+        right_arm = None
+
+        while not rospy.is_shutdown():
+            if len(self.puppet_arm_left_deque) != 0:
+                left_arm = list(self.puppet_arm_left_deque[-1].position)
+            if len(self.puppet_arm_right_deque) != 0:
+                right_arm = list(self.puppet_arm_right_deque[-1].position)
+            if left_arm is None or right_arm is None:
+                rate.sleep()
+                continue
+            else:
+                break
+
+        traj_left_list = np.linspace(left_arm, left, num_step)
+        traj_right_list = np.linspace(right_arm, right, num_step)
+
+        for i in range(len(traj_left_list)):
+            traj_left = traj_left_list[i]
+            traj_right = traj_right_list[i]
+            traj_left[-1] = left[-1]
+            traj_right[-1] = right[-1]
+            joint_state_msg = JointState()
+            joint_state_msg.header = Header()
+            joint_state_msg.header.stamp = rospy.Time.now()  # timestamp
+            joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
+            joint_state_msg.position = traj_left
+            self.puppet_arm_left_publisher.publish(joint_state_msg)
+            joint_state_msg.position = traj_right
+            self.puppet_arm_right_publisher.publish(joint_state_msg)
+            rate.sleep()
+
+    def puppet_arm_publish_continuous_thread(self, left, right):
+        if self.puppet_arm_publish_thread is not None:
+            self.puppet_arm_publish_lock.release()
+            self.puppet_arm_publish_thread.join()
+            self.puppet_arm_publish_lock.acquire(False)
+            self.puppet_arm_publish_thread = None
+        self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right))
+        self.puppet_arm_publish_thread.start()
+
+    def get_frame(self):
+        if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \
+                (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or 
len(self.img_front_depth_deque) == 0)): + return False + if self.args.use_depth_image: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(), + self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()]) + else: + frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()]) + + if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time: + return False + if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time: + return False + if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time): + return False + if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time): + return False + + while self.img_left_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_deque.popleft() + img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough') + + while self.img_right_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_deque.popleft() + img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough') + + while self.img_front_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_deque.popleft() + img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough') + + while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_left_deque.popleft() + puppet_arm_left = self.puppet_arm_left_deque.popleft() + + while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time: + self.puppet_arm_right_deque.popleft() + puppet_arm_right = self.puppet_arm_right_deque.popleft() + + img_left_depth = None + if self.args.use_depth_image: + while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_left_depth_deque.popleft() + img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough') + + img_right_depth = None + if self.args.use_depth_image: + while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_right_depth_deque.popleft() + img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough') + + img_front_depth = None + if self.args.use_depth_image: + while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time: + self.img_front_depth_deque.popleft() + img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough') + + robot_base = None 
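+        # Base odometry is optional: when use_robot_base is set, it is synced to
+        # the same frame_time as the camera and joint-state streams.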
+ if self.args.use_robot_base: + while self.robot_base_deque[0].header.stamp.to_sec() < frame_time: + self.robot_base_deque.popleft() + robot_base = self.robot_base_deque.popleft() + + return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, + puppet_arm_left, puppet_arm_right, robot_base) + + def img_left_callback(self, msg): + if len(self.img_left_deque) >= 2000: + self.img_left_deque.popleft() + self.img_left_deque.append(msg) + + def img_right_callback(self, msg): + if len(self.img_right_deque) >= 2000: + self.img_right_deque.popleft() + self.img_right_deque.append(msg) + + def img_front_callback(self, msg): + if len(self.img_front_deque) >= 2000: + self.img_front_deque.popleft() + self.img_front_deque.append(msg) + + def img_left_depth_callback(self, msg): + if len(self.img_left_depth_deque) >= 2000: + self.img_left_depth_deque.popleft() + self.img_left_depth_deque.append(msg) + + def img_right_depth_callback(self, msg): + if len(self.img_right_depth_deque) >= 2000: + self.img_right_depth_deque.popleft() + self.img_right_depth_deque.append(msg) + + def img_front_depth_callback(self, msg): + if len(self.img_front_depth_deque) >= 2000: + self.img_front_depth_deque.popleft() + self.img_front_depth_deque.append(msg) + + def puppet_arm_left_callback(self, msg): + if len(self.puppet_arm_left_deque) >= 2000: + self.puppet_arm_left_deque.popleft() + self.puppet_arm_left_deque.append(msg) + + def puppet_arm_right_callback(self, msg): + if len(self.puppet_arm_right_deque) >= 2000: + self.puppet_arm_right_deque.popleft() + self.puppet_arm_right_deque.append(msg) + + def robot_base_callback(self, msg): + if len(self.robot_base_deque) >= 2000: + self.robot_base_deque.popleft() + self.robot_base_deque.append(msg) + + def ctrl_callback(self, msg): + self.ctrl_state_lock.acquire() + self.ctrl_state = msg.data + self.ctrl_state_lock.release() + + def get_ctrl_state(self): + self.ctrl_state_lock.acquire() + state = self.ctrl_state + self.ctrl_state_lock.release() + return state + + def init_ros(self): + rospy.init_node('joint_state_publisher', anonymous=True) + rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True) + if self.args.use_depth_image: + rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True) + rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True) + self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10) + self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10) + self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, 
queue_size=10) + + +def get_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) + parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False) + parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False) + parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False) + parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False) + parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False) + parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False) + parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False) + parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False) + parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False) + parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False) + parser.add_argument('--dilation', action='store_true', + help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False) + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features", required=False) + parser.add_argument('--masks', action='store_true', + help="Train segmentation head if the flag is provided") + parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False) + parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False) + parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False) + parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False) + + parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False) + parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False) + parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False) + parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False) + parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False) + parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False) + parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False) + parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False) + parser.add_argument('--pre_norm', action='store_true', required=False) + + parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic', + default='/camera_f/color/image_raw', required=False) + parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic', + default='/camera_l/color/image_raw', required=False) + 
parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic',
+                        default='/camera_r/color/image_raw', required=False)
+
+    parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic',
+                        default='/camera_f/depth/image_raw', required=False)
+    parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic',
+                        default='/camera_l/depth/image_raw', required=False)
+    parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic',
+                        default='/camera_r/depth/image_raw', required=False)
+
+    parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic',
+                        default='/master/joint_left', required=False)
+    parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic',
+                        default='/master/joint_right', required=False)
+    parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic',
+                        default='/puppet/joint_left', required=False)
+    parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic',
+                        default='/puppet/joint_right', required=False)
+
+    parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic',
+                        default='/odom_raw', required=False)
+    parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_cmd_topic',
+                        default='/cmd_vel', required=False)
+    parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base',
+                        default=False, required=False)
+    parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate',
+                        default=40, required=False)
+    parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step',
+                        default=0, required=False)
+    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size',
+                        default=32, required=False)
+    parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length',
+                        default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False)
+
+    parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation',
+                        default=False, required=False)
+    parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image',
+                        default=False, required=False)
+
+    # for Diffusion
+    parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False)
+    parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False)
+    parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False)
+    parser.add_argument('--ema_power', action='store', type=float, help='ema_power', default=0.75, required=False)
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = get_arguments()
+    ros_operator = RosOperator(args)
+    config = get_model_config(args)
+    model_inference(args, config, ros_operator, save_episode=True)
+
+
+if __name__ == '__main__':
+    main()
+# python act/inference.py --ckpt_dir ~/train0314/
\ No newline at end of file
diff --git a/lerobot_aloha/read_parquet.py b/lerobot_aloha/read_parquet.py
new file mode 100644
index 0000000..577a1e3
--- /dev/null
+++ b/lerobot_aloha/read_parquet.py
@@ -0,0 +1,33 @@
+import pandas as pd
+
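+# Note: in a LeRobot v2 dataset, each episode is stored as a parquet file under
+# data/chunk-XXX/episode_XXXXXX.parquet (see the example path in __main__ below),
+# so this helper is a quick way to inspect a single recorded frame.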
+def read_and_print_parquet_row(file_path, row_index=0):
+    """
+    Read a parquet file and print the data of the given row.
+
+    Args:
+        file_path (str): path to the parquet file
+        row_index (int): index of the row to print (defaults to row 0)
+    """
+    try:
+        # Read the parquet file.
+        df = pd.read_parquet(file_path)
+
+        # Validate the row index.
+        if row_index >= len(df):
+            print(f"Error: row index {row_index} is out of range (the file has {len(df)} rows)")
+            return
+
+        # Print the requested row.
+        print(f"File: {file_path}")
+        print(f"Row {row_index}:\n{'-'*30}")
+        print(df.iloc[row_index])
+
+    except FileNotFoundError:
+        print(f"Error: file {file_path} does not exist")
+    except Exception as e:
+        print(f"Read failed: {str(e)}")
+
+# Example usage
+if __name__ == "__main__":
+    file_path = "example.parquet"  # replace with your parquet file path
+    read_and_print_parquet_row("/home/jgl20/LYT/work/data/data/chunk-000/episode_000000.parquet", row_index=0)  # print row 0
diff --git a/lerobot_aloha/replay_data.py b/lerobot_aloha/replay_data.py
new file mode 100644
index 0000000..6c880dc
--- /dev/null
+++ b/lerobot_aloha/replay_data.py
@@ -0,0 +1,112 @@
+#coding=utf-8
+import os
+import numpy as np
+import cv2
+import h5py
+import argparse
+import rospy
+
+from cv_bridge import CvBridge
+from std_msgs.msg import Header
+from sensor_msgs.msg import Image, JointState
+from geometry_msgs.msg import Twist
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+
+def main(args):
+    rospy.init_node("replay_node")
+    bridge = CvBridge()
+    # img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10)
+    # img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10)
+    # img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10)
+
+    # puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10)
+    # puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10)
+
+    master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10)
+    master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10)
+
+    # robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10)
+
+    # dataset_dir = args.dataset_dir
+    # episode_idx = args.episode_idx
+    # task_name = args.task_name
+    # dataset_name = f'episode_{episode_idx}'
+
+    dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode])
+    actions = dataset.hf_dataset.select_columns("action")
+    velocitys = dataset.hf_dataset.select_columns("observation.velocity")
+    efforts = dataset.hf_dataset.select_columns("observation.effort")
+
+    origin_left = [-0.0057, -0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279]
+    origin_right = [0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279]
+
+    joint_state_msg = JointState()
+    joint_state_msg.header = Header()
+    joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6']  # joint names
+    twist_msg = Twist()
+
+    # qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name)
+
+    last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647]
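+    # The hard-coded last_* vectors below seed the interpolation in the replay
+    # loop; presumably they were captured from the robot's idle state
+    # (left-arm values first, then right-arm).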
+    last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
+    last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125]
+    rate = rospy.Rate(50)  # publish at a fixed 50 Hz
+    for idx in range(len(actions)):
+        action = actions[idx]['action'].detach().cpu().numpy()
+        velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy()
+        effort = efforts[idx]['observation.effort'].detach().cpu().numpy()
+        if rospy.is_shutdown():
+            break
+
+        new_actions = np.linspace(last_action, action, 5)  # interpolate between consecutive frames
+        new_velocitys = np.linspace(last_velocity, velocity, 5)  # interpolate
+        new_efforts = np.linspace(last_effort, effort, 5)  # interpolate
+        last_action = action
+        last_velocity = velocity
+        last_effort = effort
+        for act in new_actions:
+            print(np.round(act[:7], 4))
+            cur_timestamp = rospy.Time.now()  # timestamp
+            joint_state_msg.header.stamp = cur_timestamp
+
+            joint_state_msg.position = act[:7]
+            joint_state_msg.velocity = last_velocity[:7]
+            joint_state_msg.effort = last_effort[:7]
+            master_arm_left_publisher.publish(joint_state_msg)
+
+            joint_state_msg.position = act[7:]
+            joint_state_msg.velocity = last_velocity[7:]  # right-arm half of the velocity vector
+            joint_state_msg.effort = last_effort[7:]
+            master_arm_right_publisher.publish(joint_state_msg)
+
+            if rospy.is_shutdown():
+                break
+            rate.sleep()
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    # parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic',
+    #                     default='/master/joint_left', required=False)
+    # parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic',
+    #                     default='/master/joint_right', required=False)
+
+    args = parser.parse_args()
+    args.repo_id = "tangger/test"
+    args.root = "/home/ubuntu/LYT/aloha_lerobot/data1"
+    args.episode = 1  # replay episode
+    args.master_arm_left_topic = "/master/joint_left"
+    args.master_arm_right_topic = "/master/joint_right"
+    args.fps = 30
+
+    main(args)
+    # python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0
\ No newline at end of file
diff --git a/lerobot_aloha/test.py b/lerobot_aloha/test.py
new file mode 100644
index 0000000..8eb8748
--- /dev/null
+++ b/lerobot_aloha/test.py
@@ -0,0 +1,70 @@
+from lerobot.common.policies.act.modeling_act import ACTPolicy
+from lerobot.common.robot_devices.utils import busy_wait
+import time
+import argparse
+from agilex_robot import AgilexRobot
+import torch
+
+def get_arguments():
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args()
+    args.fps = 30
+    args.resume = False
+    args.repo_id = "tangger/test"
+    args.root = "./data2"
+    args.num_image_writer_processes = 0
+    args.num_image_writer_threads_per_camera = 4
+    args.video = True
+    args.num_episodes = 50
+    args.episode_time_s = 30000
+    args.play_sounds = False
+    args.display_cameras = True
+    args.single_task = "test test"
+    args.use_depth_image = False
+    args.use_base = False
+    args.push_to_hub = False
+    args.policy = None
+    args.teleoprate = False
+    return args
+
+
+cfg = get_arguments()
+robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
+inference_time_s = 360
+fps = 30
+device = "cuda"  # TODO: On Mac, use "mps" or "cpu"
+
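+# The checkpoint directory is assumed to be the pretrained_model folder that
+# LeRobot training writes under outputs/train/<run>/checkpoints/last, which is
+# the layout ACTPolicy.from_pretrained expects.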
"/home/ubuntu/LYT/lerobot/outputs/train/act_move_tube_on_scale/checkpoints/last/pretrained_model" +policy = ACTPolicy.from_pretrained(ckpt_path) +policy.to(device) + +for _ in range(inference_time_s * fps): + start_time = time.perf_counter() + + # Read the follower state and access the frames from the cameras + observation = robot.capture_observation() + if observation is None: + print("Observation is None, skipping...") + continue + + # Convert to pytorch format: channel first and float32 in [0,1] + # with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + # Remove batch dimension + action = action.squeeze(0) + # Move to cpu, if not already the case + action = action.to("cpu") + # Order the robot to move + robot.send_action(action) + + dt_s = time.perf_counter() - start_time + busy_wait(1 / fps - dt_s) \ No newline at end of file