From e0348815077589d287771cbf99fa6d4126cd23d0 Mon Sep 17 00:00:00 2001 From: tangger Date: Mon, 7 Apr 2025 20:32:39 +0800 Subject: [PATCH] restructure code --- lerobot_aloha/collect_data_lerobot.py | 461 ----------- lerobot_aloha/common/utils/__init__.py | 12 + .../__pycache__/__init__.cpython-310.pyc | Bin 0 -> 446 bytes .../__pycache__/control_utils.cpython-310.pyc | Bin 0 -> 6820 bytes .../__pycache__/data_utils.cpython-310.pyc | Bin 0 -> 2566 bytes .../__pycache__/replay_utils.cpython-310.pyc | Bin 0 -> 1000 bytes lerobot_aloha/common/utils/control_utils.py | 303 +++++++ lerobot_aloha/common/utils/data_utils.py | 105 +++ lerobot_aloha/common/utils/replay_utils.py | 32 + lerobot_aloha/inference.py | 769 ------------------ lerobot_aloha/main.py | 56 ++ lerobot_aloha/replay_data.py | 126 +-- 12 files changed, 533 insertions(+), 1331 deletions(-) delete mode 100644 lerobot_aloha/collect_data_lerobot.py create mode 100644 lerobot_aloha/common/utils/__init__.py create mode 100644 lerobot_aloha/common/utils/__pycache__/__init__.cpython-310.pyc create mode 100644 lerobot_aloha/common/utils/__pycache__/control_utils.cpython-310.pyc create mode 100644 lerobot_aloha/common/utils/__pycache__/data_utils.cpython-310.pyc create mode 100644 lerobot_aloha/common/utils/__pycache__/replay_utils.cpython-310.pyc create mode 100644 lerobot_aloha/common/utils/control_utils.py create mode 100644 lerobot_aloha/common/utils/data_utils.py create mode 100644 lerobot_aloha/common/utils/replay_utils.py delete mode 100644 lerobot_aloha/inference.py create mode 100644 lerobot_aloha/main.py diff --git a/lerobot_aloha/collect_data_lerobot.py b/lerobot_aloha/collect_data_lerobot.py deleted file mode 100644 index 8ee0a52..0000000 --- a/lerobot_aloha/collect_data_lerobot.py +++ /dev/null @@ -1,461 +0,0 @@ -import logging -import time -from dataclasses import asdict -from pprint import pformat -from pprint import pprint - -# from safetensors.torch import load_file, save_file -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.policies.factory import make_policy -from lerobot.common.robot_devices.control_configs import ( - CalibrateControlConfig, - ControlPipelineConfig, - RecordControlConfig, - RemoteRobotConfig, - ReplayControlConfig, - TeleoperateControlConfig, -) -from lerobot.common.robot_devices.control_utils import ( - # init_keyboard_listener, - record_episode, - stop_recording, - is_headless -) -from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config -from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect -from lerobot.common.utils.utils import has_method, init_logging, log_say -from lerobot.common.utils.utils import get_safe_torch_device -from contextlib import nullcontext -from copy import copy -import torch -import rospy -import cv2 -from lerobot.configs import parser -from common.agilex_robot import AgilexRobot -from common.rosrobot_factory import RobotFactory - - -######################################################################################## -# Control modes -######################################################################################## - - -def predict_action(observation, policy, device, use_amp): - observation = copy(observation) - with ( - torch.inference_mode(), - torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), - ): - # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension - for name in observation: - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - # Compute the next action with the policy - # based on the current observation - action = policy.select_action(observation) - - # Remove batch dimension - action = action.squeeze(0) - - # Move to cpu, if not already the case - action = action.to("cpu") - - return action - -def control_loop( - robot, - control_time_s=None, - teleoperate=False, - display_cameras=False, - dataset: LeRobotDataset | None = None, - events=None, - policy = None, - fps: int | None = None, - single_task: str | None = None, -): - # TODO(rcadene): Add option to record logs - # if not robot.is_connected: - # robot.connect() - - if events is None: - events = {"exit_early": False} - - if control_time_s is None: - control_time_s = float("inf") - - if dataset is not None and single_task is None: - raise ValueError("You need to provide a task as argument in `single_task`.") - - if dataset is not None and fps is not None and dataset.fps != fps: - raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") - - timestamp = 0 - start_episode_t = time.perf_counter() - rate = rospy.Rate(fps) - print_flag = True - while timestamp < control_time_s and not rospy.is_shutdown(): - # print(timestamp < control_time_s) - # print(rospy.is_shutdown()) - start_loop_t = time.perf_counter() - - if teleoperate: - observation, action = robot.teleop_step() - if observation is None or action is None: - if print_flag: - print("sync data fail, retrying...\n") - print_flag = False - rate.sleep() - continue - else: - # pass - observation = robot.capture_observation() - if policy is not None: - pred_action = predict_action( - observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp - ) - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. - action = robot.send_action(pred_action) - action = {"action": action} - - if dataset is not None: - frame = {**observation, **action, "task": single_task} - dataset.add_frame(frame) - - # if display_cameras and not is_headless(): - # image_keys = [key for key in observation if "image" in key] - # for key in image_keys: - # if "depth" in key: - # pass - # else: - # cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) - - # print(1) - # cv2.waitKey(1) - - if display_cameras and not is_headless(): - image_keys = [key for key in observation if "image" in key] - - # 获取屏幕分辨率(假设屏幕分辨率为 1920x1080,可以根据实际情况调整) - screen_width = 1920 - screen_height = 1080 - - # 计算窗口的排列方式 - num_images = len(image_keys) - max_columns = int(screen_width / 640) # 假设每个窗口宽度为 640 - rows = (num_images + max_columns - 1) // max_columns # 计算需要的行数 - columns = min(num_images, max_columns) # 实际使用的列数 - - # 遍历所有图像键并显示 - for idx, key in enumerate(image_keys): - if "depth" in key: - continue # 跳过深度图像 - - # 将图像从 RGB 转换为 BGR 格式 - image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) - - # 创建窗口 - cv2.imshow(key, image) - - # 计算窗口位置 - window_width = 640 - window_height = 480 - row = idx // max_columns - col = idx % max_columns - x_position = col * window_width - y_position = row * window_height - - # 移动窗口到指定位置 - cv2.moveWindow(key, x_position, y_position) - - # 等待 1 毫秒以处理事件 - cv2.waitKey(1) - - if fps is not None: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - # log_control_info(robot, dt_s, fps=fps) - - timestamp = time.perf_counter() - start_episode_t - if events["exit_early"]: - events["exit_early"] = False - break - - -def init_keyboard_listener(): - # Allow to exit early while recording an episode or resetting the environment, - # by tapping the right arrow key '->'. This might require a sudo permission - # to allow your terminal to monitor keyboard events. - events = {} - events["exit_early"] = False - events["record_start"] = False - events["rerecord_episode"] = False - events["stop_recording"] = False - - if is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." - ) - listener = None - return listener, events - - # Only import pynput if not in a headless environment - from pynput import keyboard - - def on_press(key): - try: - if key == keyboard.Key.right: - print("Right arrow key pressed. Exiting loop...") - events["exit_early"] = True - events["record_start"] = False - elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - print("Escape key pressed. Stopping data recording...") - events["stop_recording"] = True - events["exit_early"] = True - elif key == keyboard.Key.up: - print("Up arrow pressed. Start data recording...") - events["record_start"] = True - - - except Exception as e: - print(f"Error handling key press: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - - return listener, events - - -def stop_recording(robot, listener, display_cameras): - - if not is_headless(): - if listener is not None: - listener.stop() - - if display_cameras: - cv2.destroyAllWindows() - - -def record_episode( - robot, - dataset, - events, - episode_time_s, - display_cameras, - policy, - fps, - single_task, -): - control_loop( - robot=robot, - control_time_s=episode_time_s, - display_cameras=display_cameras, - dataset=dataset, - events=events, - policy=policy, - fps=fps, - teleoperate=policy is None, - single_task=single_task, - ) - - -def record( - robot, - cfg -) -> LeRobotDataset: - # TODO(rcadene): Add option to record logs - if cfg.resume: - dataset = LeRobotDataset( - cfg.repo_id, - root=cfg.root, - ) - if len(robot.cameras) > 0: - dataset.start_image_writer( - num_processes=cfg.num_image_writer_processes, - num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), - ) - # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) - else: - # Create empty dataset or load existing saved episodes - # sanity_check_dataset_name(cfg.repo_id, cfg.policy) - dataset = LeRobotDataset.create( - cfg.repo_id, - cfg.fps, - root=cfg.root, - robot=None, - features=robot.features, - use_videos=cfg.video, - image_writer_processes=cfg.num_image_writer_processes, - image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), - ) - - # Load pretrained policy - policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) - # policy = None - - # if not robot.is_connected: - # robot.connect() - - listener, events = init_keyboard_listener() - - # Execute a few seconds without recording to: - # 1. teleoperate the robot to move it in starting position if no policy provided, - # 2. give times to the robot devices to connect and start synchronizing, - # 3. place the cameras windows on screen - enable_teleoperation = policy is None - log_say("Warmup record", cfg.play_sounds) - print() - print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n") - # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps) - - # if has_method(robot, "teleop_safety_stop"): - # robot.teleop_safety_stop() - - recorded_episodes = 0 - while True: - if recorded_episodes >= cfg.num_episodes: - break - - # if events["record_start"]: - log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) - pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}") - record_episode( - robot=robot, - dataset=dataset, - events=events, - episode_time_s=cfg.episode_time_s, - display_cameras=cfg.display_cameras, - policy=policy, - fps=cfg.fps, - single_task=cfg.single_task, - ) - - # Execute a few seconds without recording to give time to manually reset the environment - # Current code logic doesn't allow to teleoperate during this time. - # TODO(rcadene): add an option to enable teleoperation during reset - # Skip reset for the last episode to be recorded - if not events["stop_recording"] and ( - (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment", cfg.play_sounds) - pprint("Reset the environment, stop recording") - # reset_environment(robot, events, cfg.reset_time_s, cfg.fps) - - if events["rerecord_episode"]: - log_say("Re-record episode", cfg.play_sounds) - pprint("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue - - dataset.save_episode() - recorded_episodes += 1 - - if events["stop_recording"]: - break - - log_say("Stop recording", cfg.play_sounds, blocking=True) - stop_recording(robot, listener, cfg.display_cameras) - - if cfg.push_to_hub: - dataset.push_to_hub(tags=cfg.tags, private=cfg.private) - - log_say("Exiting", cfg.play_sounds) - return dataset - - -def replay( - robot: AgilexRobot, - cfg, -): - # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset - # TODO(rcadene): Add option to record logs - - dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) - actions = dataset.hf_dataset.select_columns("action") - - # if not robot.is_connected: - # robot.connect() - - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() - - action = actions[idx]["action"] - robot.send_action(action) - - dt_s = time.perf_counter() - start_episode_t - busy_wait(1 / cfg.fps - dt_s) - - dt_s = time.perf_counter() - start_episode_t - # log_control_info(robot, dt_s, fps=cfg.fps) - - -import argparse -def get_arguments(): - parser = argparse.ArgumentParser() - args = parser.parse_args() - args.fps = 30 - args.resume = False - args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right" - args.root = "./data5" - args.episode = 0 # replay episode - args.num_image_writer_processes = 0 - args.num_image_writer_threads_per_camera = 4 - args.video = True - args.num_episodes = 100 - args.episode_time_s = 30000 - args.play_sounds = False - args.display_cameras = True - args.single_task = "move the bottle from the right to the scale right" - args.use_depth_image = False - args.use_base = False - args.push_to_hub = False - args.policy = None - # args.teleoprate = True - args.control_type = "record" - # args.control_type = "replay" - return args - - - - -# @parser.wrap() -def control_robot(cfg): - # 使用工厂模式创建机器人实例 - robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg) - - if cfg.control_type == "record": - record(robot, cfg) - elif cfg.control_type == "replay": - replay(robot, cfg) - - -if __name__ == "__main__": - cfg = get_arguments() - control_robot(cfg) - # control_robot() - # 使用工厂模式创建机器人实例 - # robot = RobotFactory.create(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg) - # print(robot.features.items()) - # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"]) - # record(robot, cfg) - # capture = robot.capture_observation() - # import torch - # torch.save(capture, "test.pt") - # action = torch.tensor([[ 0.0277, 0.0167, 0.0142, -0.1628, 0.1473, -0.0296, 0.0238, -0.1094, - # 0.0109, 0.0139, -0.1591, -0.1490, -0.1650, -0.0980]], - # device='cpu') - # robot.send_action(action.squeeze(0)) - # print() \ No newline at end of file diff --git a/lerobot_aloha/common/utils/__init__.py b/lerobot_aloha/common/utils/__init__.py new file mode 100644 index 0000000..c7e4a8f --- /dev/null +++ b/lerobot_aloha/common/utils/__init__.py @@ -0,0 +1,12 @@ +# Import utility functions for easy access +from .control_utils import ( + predict_action, + control_loop, + init_keyboard_listener, + stop_recording, + record_episode, + is_headless, + busy_wait +) +from .data_utils import record +from .replay_utils import replay diff --git a/lerobot_aloha/common/utils/__pycache__/__init__.cpython-310.pyc b/lerobot_aloha/common/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f1ade86f55cfd68ba9c5bb12fd65a4c62ab3a0e GIT binary patch literal 446 zcmZ9I!Ait15QdYs+irK;MG$-fFTL0o5ZQ}2@uC-V32hU0gGr|(Dbh#sDSVGyJ$d%# z$<9_m9LUE%!_Q1+pe#3Jj{V!`;hYlkj>+GM=43xVhY&fmThO9JdkSjG0z}X7IuS`LUo=uombF@qJ!RB&|Cuy;X?Jw z^Jh7=&{)B>a$fe*hMn^;a4S2oLiLBRTO^l*av)^bD#wqq5n4JI@^*9+e`%DjXh;|5 zq6Xbe_f^5l-=U*d+F4xH#)t&5LZpZcv7Qk&qVWx1>^T++E3w;82~YV-Y>9rv*KK_S zBkQppdq389Pmi^h7FwA1)^Ke8WgQp;eSKcJHs5>FF5`DSnIkaM|l*HHkgS%2~r8w9eYHxF3?~jijvNNfxussUq2A zPxq+0n-m8<3nMhL_QBZAEr7ECdGKNXf&e)z@(<(|R%+Zu0*mdaw7rB04YhU~vtI#&vde#UM-|vS(EFQ#9O*x5Ch0XAfi|ehcqc*oNyzf`$0LJoF-wbCo(4 zxViD5UCC?E|>0MR!xT7jYU@DB8p{r)HX6B-p?!<(zOceEdUV{CCVwGK3SGu2Y-qIP0}x_YW}>k=qZ>!wQ4 zUgH(g_fzx4K0Vu0=b%4B>!e26^5_b5r;%!>`nZx-PMSokN)ETF1Wl@2nx!3^x)#iR zLyL9~^-`~to=tkw%6-ziGj-5^FilT1ZO6R^p(xGyffU^9$IR`=UKp?~7zbx4gjT87DHWU=w9alNQj%6rHBNLj=d;d!p!l&t$b~=Y%ncE$ zOlHaQ+N3jyM9&=z*V)5eLHf($kfDvaNgNI^-98(5kFkWI6n!tkk|rZn;Tfw@w%>Um zf+&=jPc99N&*h=U07DY)613!V>BUDwkyX5byCXh-0yX=ba6*~S=VU{f*&imG?HOrO zrJ}x&wY*?kKpuV3qt#}XVq(*@`!1}BSY)+rKXl{QZe+Egkb?xK=8#c%J4u*CSv?7& zM~M(40kaFcGm-E`A9kg7E~^!N?P}JT5_wk5Wt~;?$!E1ff|(V#1Cdqp;mUaPge)y< zL4Kqc1^26Hi>@KDXR3zvgc;Va?TW3h>TCMEzNkBvV>Zoe|86Ze|IJ#S!z5&h-b))@0*=PBD4TTBL{v||_R=A#M(&h#?J+cN})#qX$DePEJE!1^+ zIki>$VG}8_KdW)Z_3v}QN^g*MeS3K+i4AH{TFHscgFSDJh6c0 zPFy*urZreygEz*_w4Unw+Rt_QEwtB8QHuVkpqxn!8Nvs^(*N8Tx5l$6yjP+{ZVK;( znFnc8dU2gMk729a1h2-4Lw@1Wt>7xgHSo1oT2-_BByOfC$1`blueCROtgC$P*c#7) z@-J)Qkb=dINC^w0G1h15EkPcNL}i>ZlO+wsCN&DrL& zeC|Wgze5-}&o7J@`22V=UQFjvUy7H~1-@{J`j*qBba_SNi>q3?yrzv;V%2hNj923o zq6Qv*6YIS=UgJyY8Zhyn9{oeSdU6r-UjqWJooeyw-lbDxdPzS=Gib z#V_#1v<4J@X%(~AwKIKcM^guMb!d}uk@oQhkR0Wkh<5I4lihtQ-(C4GwNAh775CE; zC*O6wfE6B!`C&L@+o5F2v#}7?s*v)zz^LffnTkM|7euiOAebOi>G)!AAbJsyGk1iH z>Y!$k00oaG^f3%S|E72b2cPy$V_y+T^!hM)>EaVp=vhV?c^}&rV(L?c{jRejZpv4U z`A801p3F!b0O`T@Hsy$06ePLwXhQInfDdd~XB!f8ix@;%tambY3N_#qfE|pGOLcE% zj(F_FJ>g3KXd^QLvZFWN4-*z(733O*GCc4&+yf)fVQ$1+xs&A4R=|Ed#rnsc(TxWP z$chy+Oo~OjVd8VPg)r?=;`+o;qJ-!N+=<|t{5qg(WU&{&#YRTEqnyBl(dFnU=qpZa z8=>1P7!}JS2&~iTIPzr-0sa#RlDjA$@&FXX+{YWOBH2L@@gsYTU_8}fYDTL6mGQ`c z`_|3dO(2y6J}$HbEOuWF-GiF>5Potr+~j~~5>LP_j> zP}4N;4NS%R0Au6=#$^>QhVicVFPLzfWfkS8GUtb`pNQL1hEmA^i-N{xP3+-zuOA`; z6f&#GFhVS6-J{5nBD!cdiFtS!WQ|-`dywWZs|+PV&a4vo0=0#{J4ARZdQ+ZExfEHU z*;HW<+un|R5f%9owKO6T@WNSj>n`WLZPF(d8fK<10@95*GY4Lf)kOdefW-=U>Vq4Z z-9LzLg?@;+q62fCx%J-N_wM!Xedp~PZ-3`rR`mu@gu~1kga_iG7hv*P?a=k&JK`v- zZza)Dk816&(OMDdX05XZT`n~Li+ZEolN5X<6t!V&(%W!C73`H{8_d#A&*Zw{=F(f(8Ir~0-yeEFzohL=t|DW*hK)SB_~!$mE|fp<-?v0mXo0mmvgNg z{z@^X&0b|NhI72#fEWGeJ;H`;LoQhpF*+Ox005eu%}K<^s$S@L){0_ia^48=9QD6j zo{pza&0NF?g9Y!f_kve*-DQQJdwPP#R3?k}f?*Oz?9lUl7KAau`{ab(*;EjYVBaq3 zp8;vQS0PZ0{u#;WJ_N4sX?t{*T#gN{tBB7iw%J0C?Zv4KxL#12p!|JMRwtA-C_*vf z45bCiZ-dgFP}ZMASpnrAgVLE$PQIFRw6wT-G|+Klp=+ZSC>yE%srCzl=pWJz1$ur0 zR|9oZr9=HtYd1GX*X|Jtz-z)}NrZH;if~5ku-kCF*cBBF!LL0~d%7VXl2C5n72Cgc z5F-UCH(ebseVl|?vHM^t|CD%jZb$G2N_n4t{5~{eNM9#+4JJiELWAC?LG3>o7IU7= z_daRav-G1i<-pi3X5$mb$s*TTyOx=E0P+<{LXdBOb7uKiT4suo?{tz545xby>GRFrJLA zc`##wYh`4;(gL44@RB*A1A?P=E%$vjKit{DvdCk}0EKRKcti%EW|!G_3nYN&CHWwp zuuI;GPd@*x$s(42^F02Jmc)=)&g62&uz;f06ANhCG+RbAnFq~mJgK)GjHLI zI!Fi$qXl*3(umG6~BcqiY22(k;s728|B zpGO=~+mzo!kKCo?E+tB{ChDd1{y)$jwUAWnv<>9`tS>D~_4Vn-w|uBfLzU1ef}$I; z6!%=Io*HQ-t>!0ioTH6T=uVgJSWWlO5dNtt-%|$1T$TsW6Ps0?D*4|?_y6jUF^P?e zbJXoTR(kF^>bX0HIR7h8MgL!?nRY$jDFb5MxB^4Rop%R6MFm;MP+IBu_ zLYuYDN+M^)XxJpQJf}Tb>#XmE!D-PlMi?uInx=kLl$|f+0A0^G#)%cbiQW`Sv#-l4 zM+6fWm*sh$>bPFe_BDB>Ui6eL(L7Pp+crl-rAKcgyk-<$t55|BH^jZ(wukc<;?gZd zFyVH2K7Ouk7w6}7wt0RCo#=?U`c!@a4j8{dl6nmZ(6Ckv&YV2BrhxpM?BsVr`Z5WO zVBW4#s5u8Qtv*pbZF_TLW0RitlH|l%@n53SA5x+WUS0IzFALw>Qg?ys^052$;_|$c zUzB$WONlz=rFzjI>EPI1J4c;gq^ne&n-t3A$0*vN%*4H;(jW!O3MACVs%cLxa)L3L zeG`|Li7(#7YV&_YY^}umCY%`kY}1}M+S57%FYraRVtiS(jsL3FoiFUf{-%AQ`SSk& D*x@6& literal 0 HcmV?d00001 diff --git a/lerobot_aloha/common/utils/__pycache__/data_utils.cpython-310.pyc b/lerobot_aloha/common/utils/__pycache__/data_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd02da3936056a34d0d9b78de1a6ab23fb6d65e6 GIT binary patch literal 2566 zcmaJ@&2Jk;6yGneH@4$^)Jc;-+vQu9N+JOgqN)&8DO{=^(q330mW^j@Z?Yfm%xp>` z2L+|56dDy&+Dgy}3Q{PS0s@tOQt=OP;kM#N9AYjUxKv0W-psCBCscSfGjHB|^WK}8 z-+SY*w>Kxj@3kNPtUXthq+j`@`=`R>1!!~{IwT?4ldNQvEg9yDr}&Dc$egEos;^m^ zpfxY!>y|F)jFOwoS(Pye!(jEy;d*qvRJDs$Ur^EHLHl^!%{id27l6_ zvK;G;25|!hRPi9545vc&qQe}DIVbNsGuUo~o?DqKD{i;a-N8AVN0W;UOLomuY-8 z+uoF#N;AW4LK|tS(DNOPqTXdgmYNw<9F^#5VjuO5NHY?>0NnmjiL8RB0nqJu7~~dH z*^eOO7CW+SBGwX@?v)T9hxUT!Psc98?vB0=)B~Jq>ZtEoX+Z}$_$5CD@=*sY4RcE< zhqf=P5c7i&vj&`zC78jugZvKr-2v22XuHq|zY|zB3NzuqsQ?g6!OOQUjLFC&F z5>_y!n8tZNWp#obL=Tm9!eJk0MD(#y#SY{KXyqb`?O7M$kV2yQP7T|~h|4hPvQ*l{ zPfeC(D~m*yI?i^hopIBZ$ z$M=)kK5%&d7VY-pV(Y~9B$8C`$oyaeGbl-&`!eTCy8&R#@j4u8y3{Ni1ZuVvXE_w1 zBE~wN#Q~#nF}-J-xjwdO+=pD+@SHij0%t-T3Pn}a%vh~9;G!wG@~{myW{3d>+Z+&a zKf#*|;{4Dg1}`S7V-p9nE(rr4+|56BU-oX3m)8_};9!1W5}=wtIEUYW(u^JA z&CM4NZ_)>J8dF|anSr$cPK`M=8fW1xm~|MovIk%xmzcav(+IZ3pQ<=Za3i!`l%Q4h za1duZn;wKzBh2mwV?4$|-p%QP?IZt@rO>q-0Pn8GT~uS5NyW_>&}*0CG`C;n24C`(FqPwTae?8Ieurn zZu0*g^iAmWK1J5x)Bh=Id@KT`ENVKZH93(9?wy^5oT{s1vZg9>4kRUI1p4H=r>oyF z3d3?iF=Rt2$k2;Vff?t3_AmX9gAZj8bTpYf10j_A#Q%pHfC|6{c_S_$ej?rBc(R-E zG}Y*MXXhzHz|Y#XMO&nqi|KgPsj!gDC4d;)qLmzr=midl4)IuHsmDYLZwn)#mmHx0 tlep*vAUu9?LtWj|z>AH&euyFuKhHCo^4m2eLoSG~UoOdts;I@k{{{2pL7e~q literal 0 HcmV?d00001 diff --git a/lerobot_aloha/common/utils/__pycache__/replay_utils.cpython-310.pyc b/lerobot_aloha/common/utils/__pycache__/replay_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4788ef5fd791cced7c808e743bdbf02b2bf7a95f GIT binary patch literal 1000 zcmZ8g&1=*^6rag=*X>qBN=Ke2-%Iq( zyTBNM-3PxAuw!v#&;{N{Z^_D2YgdD7wDMcG^|qmzZL|tnbak%{u1%~xZT&5@3e}JX zz`MSS+HeaQrgu+h2<=7ZUj_u$5vz;jj0l!;K^Y~&bV?|+N|Myr@+u>WlayD>JZnUv za$Zers8ugTCTDhM*lhHS*m04(k}9cE_U|-ZW;3#Pmx%DiDNB|07(3b%oGTr%I+vU> zse>dn>CPW?Rhm-zGjcv6OTkN9DO?KRW3G|{j7iDljND+D|KWHf%mwLCu$sp?6`n;s zi5{FRW78@S=#gXvloF@BXv#`zzZQ9=^q^|WxFglRNUDr!Pvs@k!-6+9}F29|O drM7L|QM)Ej^sMBxDcEy!chdY#5gKB6+`s+`9D4u& literal 0 HcmV?d00001 diff --git a/lerobot_aloha/common/utils/control_utils.py b/lerobot_aloha/common/utils/control_utils.py new file mode 100644 index 0000000..11ae4d9 --- /dev/null +++ b/lerobot_aloha/common/utils/control_utils.py @@ -0,0 +1,303 @@ +import logging +import time +import torch +import rospy +import cv2 +from contextlib import nullcontext +from copy import copy +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.utils.utils import get_safe_torch_device + + +def is_headless(): + """ + Check if the environment is headless (no display available). + + Returns: + bool: True if the environment is headless, False otherwise. + """ + try: + import tkinter as tk + root = tk.Tk() + root.withdraw() + root.update() + root.destroy() + return False + except: + return True + + +def predict_action(observation, policy, device, use_amp): + """ + Predict action based on observation using the policy. + + Args: + observation: Current observation + policy: Policy model + device: Torch device + use_amp: Whether to use automatic mixed precision + + Returns: + torch.Tensor: Predicted action + """ + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + + +def control_loop( + robot, + control_time_s=None, + teleoperate=False, + display_cameras=False, + dataset: LeRobotDataset | None = None, + events=None, + policy = None, + fps: int | None = None, + single_task: str | None = None, +): + """ + Main control loop for robot operation. + + Args: + robot: Robot instance + control_time_s: Control time in seconds + teleoperate: Whether to use teleoperation + display_cameras: Whether to display camera feeds + dataset: Dataset for recording + events: Event dictionary + policy: Policy model + fps: Frames per second + single_task: Task name + """ + if events is None: + events = {"exit_early": False} + + if control_time_s is None: + control_time_s = float("inf") + + if dataset is not None and single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + + if dataset is not None and fps is not None and dataset.fps != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") + + timestamp = 0 + start_episode_t = time.perf_counter() + rate = rospy.Rate(fps) + print_flag = True + while timestamp < control_time_s and not rospy.is_shutdown(): + start_loop_t = time.perf_counter() + + if teleoperate: + observation, action = robot.teleop_step() + if observation is None or action is None: + if print_flag: + print("sync data fail, retrying...\n") + print_flag = False + rate.sleep() + continue + else: + observation = robot.capture_observation() + if policy is not None: + pred_action = predict_action( + observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + ) + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + action = robot.send_action(pred_action) + action = {"action": action} + + if dataset is not None: + frame = {**observation, **action, "task": single_task} + dataset.add_frame(frame) + + if display_cameras and not is_headless(): + image_keys = [key for key in observation if "image" in key] + + # 获取屏幕分辨率(假设屏幕分辨率为 1920x1080,可以根据实际情况调整) + screen_width = 1920 + screen_height = 1080 + + # 计算窗口的排列方式 + num_images = len(image_keys) + max_columns = int(screen_width / 640) # 假设每个窗口宽度为 640 + rows = (num_images + max_columns - 1) // max_columns # 计算需要的行数 + columns = min(num_images, max_columns) # 实际使用的列数 + + # 遍历所有图像键并显示 + for idx, key in enumerate(image_keys): + if "depth" in key: + continue # 跳过深度图像 + + # 将图像从 RGB 转换为 BGR 格式 + image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + + # 创建窗口 + cv2.imshow(key, image) + + # 计算窗口位置 + window_width = 640 + window_height = 480 + row = idx // max_columns + col = idx % max_columns + x_position = col * window_width + y_position = row * window_height + + # 移动窗口到指定位置 + cv2.moveWindow(key, x_position, y_position) + + # 等待 1 毫秒以处理事件 + cv2.waitKey(1) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + break + + +def init_keyboard_listener(): + """ + Initialize keyboard listener for control events. + + Returns: + tuple: (listener, events) - Keyboard listener and events dictionary + """ + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["record_start"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + events["record_start"] = False + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + elif key == keyboard.Key.up: + print("Up arrow pressed. Start data recording...") + events["record_start"] = True + + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def stop_recording(robot, listener, display_cameras): + """ + Stop recording and clean up resources. + + Args: + robot: Robot instance + listener: Keyboard listener + display_cameras: Whether cameras are being displayed + """ + if not is_headless(): + if listener is not None: + listener.stop() + + if display_cameras: + cv2.destroyAllWindows() + + +def record_episode( + robot, + dataset, + events, + episode_time_s, + display_cameras, + policy, + fps, + single_task, +): + """ + Record a single episode. + + Args: + robot: Robot instance + dataset: Dataset for recording + events: Event dictionary + episode_time_s: Episode time in seconds + display_cameras: Whether to display camera feeds + policy: Policy model + fps: Frames per second + single_task: Task name + """ + control_loop( + robot=robot, + control_time_s=episode_time_s, + display_cameras=display_cameras, + dataset=dataset, + events=events, + policy=policy, + fps=fps, + teleoperate=policy is None, + single_task=single_task, + ) + + +def busy_wait(seconds): + """ + Busy wait for a specified number of seconds. + + Args: + seconds: Number of seconds to wait + """ + if seconds <= 0: + return + start_time = time.perf_counter() + while time.perf_counter() - start_time < seconds: + pass diff --git a/lerobot_aloha/common/utils/data_utils.py b/lerobot_aloha/common/utils/data_utils.py new file mode 100644 index 0000000..040c8d4 --- /dev/null +++ b/lerobot_aloha/common/utils/data_utils.py @@ -0,0 +1,105 @@ +import logging +import time +from pprint import pprint +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.utils.utils import log_say, has_method +from common.utils.control_utils import init_keyboard_listener, stop_recording, record_episode + + +def record( + robot, + cfg +) -> LeRobotDataset: + """ + Record robot data according to configuration. + + Args: + robot: Robot instance + cfg: Configuration object + + Returns: + LeRobotDataset: Dataset with recorded episodes + """ + # Initialize or load dataset + if cfg.resume: + dataset = LeRobotDataset( + cfg.repo_id, + root=cfg.root, + ) + if len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.num_image_writer_processes, + num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + else: + # Create empty dataset or load existing saved episodes + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.root, + robot=None, + features=robot.features, + use_videos=cfg.video, + image_writer_processes=cfg.num_image_writer_processes, + image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + + # Initialize keyboard listener + listener, events = init_keyboard_listener() + + # Print recording instructions + print() + print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n") + + # Record episodes + recorded_episodes = 0 + while True: + if recorded_episodes >= cfg.num_episodes: + break + + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) + pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}") + record_episode( + robot=robot, + dataset=dataset, + events=events, + episode_time_s=cfg.episode_time_s, + display_cameras=cfg.display_cameras, + policy=policy, + fps=cfg.fps, + single_task=cfg.single_task, + ) + + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + pprint("Reset the environment, stop recording") + + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + pprint("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 + + if events["stop_recording"]: + break + + log_say("Stop recording", cfg.play_sounds, blocking=True) + stop_recording(robot, listener, cfg.display_cameras) + + if cfg.push_to_hub: + dataset.push_to_hub(tags=cfg.tags, private=cfg.private) + + log_say("Exiting", cfg.play_sounds) + return dataset diff --git a/lerobot_aloha/common/utils/replay_utils.py b/lerobot_aloha/common/utils/replay_utils.py new file mode 100644 index 0000000..8b08d40 --- /dev/null +++ b/lerobot_aloha/common/utils/replay_utils.py @@ -0,0 +1,32 @@ +import time +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from common.utils.control_utils import busy_wait + + +def replay( + robot, + cfg, +): + """ + Replay recorded robot data according to configuration. + + Args: + robot: Robot instance + cfg: Configuration object + """ + # Load dataset + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) + actions = dataset.hf_dataset.select_columns("action") + + print(f"Replaying episode {cfg.episode} from dataset {cfg.repo_id}") + print(f"Total frames: {dataset.num_frames}") + + # Replay each frame + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + robot.send_action(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / cfg.fps - dt_s) diff --git a/lerobot_aloha/inference.py b/lerobot_aloha/inference.py deleted file mode 100644 index 34f7f52..0000000 --- a/lerobot_aloha/inference.py +++ /dev/null @@ -1,769 +0,0 @@ -#!/home/lin/software/miniconda3/envs/aloha/bin/python -# -- coding: UTF-8 -""" -#!/usr/bin/python3 -""" - -import torch -import numpy as np -import os -import pickle -import argparse -from einops import rearrange -import collections -from collections import deque - -import rospy -from std_msgs.msg import Header -from geometry_msgs.msg import Twist -from sensor_msgs.msg import JointState, Image -from nav_msgs.msg import Odometry -from cv_bridge import CvBridge -import time -import threading -import math -import threading - - - - -import sys -sys.path.append("./") - -SEED = 42 -torch.manual_seed(SEED) -np.random.seed(SEED) - -task_config = {'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']} - -inference_thread = None -inference_lock = threading.Lock() -inference_actions = None -inference_timestep = None - - -def actions_interpolation(args, pre_action, actions, stats): - steps = np.concatenate((np.array(args.arm_steps_length), np.array(args.arm_steps_length)), axis=0) - pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] - post_process = lambda a: a * stats['action_std'] + stats['action_mean'] - result = [pre_action] - post_action = post_process(actions[0]) - # print("pre_action:", pre_action[7:]) - # print("actions_interpolation1:", post_action[:, 7:]) - max_diff_index = 0 - max_diff = -1 - for i in range(post_action.shape[0]): - diff = 0 - for j in range(pre_action.shape[0]): - if j == 6 or j == 13: - continue - diff += math.fabs(pre_action[j] - post_action[i][j]) - if diff > max_diff: - max_diff = diff - max_diff_index = i - - for i in range(max_diff_index, post_action.shape[0]): - step = max([math.floor(math.fabs(result[-1][j] - post_action[i][j])/steps[j]) for j in range(pre_action.shape[0])]) - inter = np.linspace(result[-1], post_action[i], step+2) - result.extend(inter[1:]) - while len(result) < args.chunk_size+1: - result.append(result[-1]) - result = np.array(result)[1:args.chunk_size+1] - # print("actions_interpolation2:", result.shape, result[:, 7:]) - result = pre_process(result) - result = result[np.newaxis, :] - return result - - -def get_model_config(args): - # 设置随机种子,你可以确保在相同的初始条件下,每次运行代码时生成的随机数序列是相同的。 - set_seed(1) - - # 如果是ACT策略 - # fixed parameters - if args.policy_class == 'ACT': - policy_config = {'lr': args.lr, - 'lr_backbone': args.lr_backbone, - 'backbone': args.backbone, - 'masks': args.masks, - 'weight_decay': args.weight_decay, - 'dilation': args.dilation, - 'position_embedding': args.position_embedding, - 'loss_function': args.loss_function, - 'chunk_size': args.chunk_size, # 查询 - 'camera_names': task_config['camera_names'], - 'use_depth_image': args.use_depth_image, - 'use_robot_base': args.use_robot_base, - 'kl_weight': args.kl_weight, # kl散度权重 - 'hidden_dim': args.hidden_dim, # 隐藏层维度 - 'dim_feedforward': args.dim_feedforward, - 'enc_layers': args.enc_layers, - 'dec_layers': args.dec_layers, - 'nheads': args.nheads, - 'dropout': args.dropout, - 'pre_norm': args.pre_norm - } - elif args.policy_class == 'CNNMLP': - policy_config = {'lr': args.lr, - 'lr_backbone': args.lr_backbone, - 'backbone': args.backbone, - 'masks': args.masks, - 'weight_decay': args.weight_decay, - 'dilation': args.dilation, - 'position_embedding': args.position_embedding, - 'loss_function': args.loss_function, - 'chunk_size': 1, # 查询 - 'camera_names': task_config['camera_names'], - 'use_depth_image': args.use_depth_image, - 'use_robot_base': args.use_robot_base - } - - elif args.policy_class == 'Diffusion': - policy_config = {'lr': args.lr, - 'lr_backbone': args.lr_backbone, - 'backbone': args.backbone, - 'masks': args.masks, - 'weight_decay': args.weight_decay, - 'dilation': args.dilation, - 'position_embedding': args.position_embedding, - 'loss_function': args.loss_function, - 'chunk_size': args.chunk_size, # 查询 - 'camera_names': task_config['camera_names'], - 'use_depth_image': args.use_depth_image, - 'use_robot_base': args.use_robot_base, - 'observation_horizon': args.observation_horizon, - 'action_horizon': args.action_horizon, - 'num_inference_timesteps': args.num_inference_timesteps, - 'ema_power': args.ema_power - } - else: - raise NotImplementedError - - config = { - 'ckpt_dir': args.ckpt_dir, - 'ckpt_name': args.ckpt_name, - 'ckpt_stats_name': args.ckpt_stats_name, - 'episode_len': args.max_publish_step, - 'state_dim': args.state_dim, - 'policy_class': args.policy_class, - 'policy_config': policy_config, - 'temporal_agg': args.temporal_agg, - 'camera_names': task_config['camera_names'], - } - return config - - -def make_policy(policy_class, policy_config): - if policy_class == 'ACT': - policy = ACTPolicy(policy_config) - elif policy_class == 'CNNMLP': - policy = CNNMLPPolicy(policy_config) - elif policy_class == 'Diffusion': - policy = DiffusionPolicy(policy_config) - else: - raise NotImplementedError - return policy - - -def get_image(observation, camera_names): - curr_images = [] - for cam_name in camera_names: - curr_image = rearrange(observation['images'][cam_name], 'h w c -> c h w') - - curr_images.append(curr_image) - curr_image = np.stack(curr_images, axis=0) - curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) - return curr_image - - -def get_depth_image(observation, camera_names): - curr_images = [] - for cam_name in camera_names: - curr_images.append(observation['images_depth'][cam_name]) - curr_image = np.stack(curr_images, axis=0) - curr_image = torch.from_numpy(curr_image / 255.0).float().cuda().unsqueeze(0) - return curr_image - - -def inference_process(args, config, ros_operator, policy, stats, t, pre_action): - global inference_lock - global inference_actions - global inference_timestep - print_flag = True - pre_pos_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] - pre_action_process = lambda next_action: (next_action - stats["action_mean"]) / stats["action_std"] - rate = rospy.Rate(args.publish_rate) - while True and not rospy.is_shutdown(): - result = ros_operator.get_frame() - if not result: - if print_flag: - print("syn fail") - print_flag = False - rate.sleep() - continue - print_flag = True - (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, - puppet_arm_left, puppet_arm_right, robot_base) = result - obs = collections.OrderedDict() - image_dict = dict() - - image_dict[config['camera_names'][0]] = img_front - image_dict[config['camera_names'][1]] = img_left - image_dict[config['camera_names'][2]] = img_right - - - obs['images'] = image_dict - - if args.use_depth_image: - image_depth_dict = dict() - image_depth_dict[config['camera_names'][0]] = img_front_depth - image_depth_dict[config['camera_names'][1]] = img_left_depth - image_depth_dict[config['camera_names'][2]] = img_right_depth - obs['images_depth'] = image_depth_dict - - obs['qpos'] = np.concatenate( - (np.array(puppet_arm_left.position), np.array(puppet_arm_right.position)), axis=0) - obs['qvel'] = np.concatenate( - (np.array(puppet_arm_left.velocity), np.array(puppet_arm_right.velocity)), axis=0) - obs['effort'] = np.concatenate( - (np.array(puppet_arm_left.effort), np.array(puppet_arm_right.effort)), axis=0) - if args.use_robot_base: - obs['base_vel'] = [robot_base.twist.twist.linear.x, robot_base.twist.twist.angular.z] - obs['qpos'] = np.concatenate((obs['qpos'], obs['base_vel']), axis=0) - else: - obs['base_vel'] = [0.0, 0.0] - # qpos_numpy = np.array(obs['qpos']) - - # 归一化处理qpos 并转到cuda - qpos = pre_pos_process(obs['qpos']) - qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0) - # 当前图像curr_image获取图像 - curr_image = get_image(obs, config['camera_names']) - curr_depth_image = None - if args.use_depth_image: - curr_depth_image = get_depth_image(obs, config['camera_names']) - start_time = time.time() - all_actions = policy(curr_image, curr_depth_image, qpos) - end_time = time.time() - print("model cost time: ", end_time -start_time) - inference_lock.acquire() - inference_actions = all_actions.cpu().detach().numpy() - if pre_action is None: - pre_action = obs['qpos'] - # print("obs['qpos']:", obs['qpos'][7:]) - if args.use_actions_interpolation: - inference_actions = actions_interpolation(args, pre_action, inference_actions, stats) - inference_timestep = t - inference_lock.release() - break - - -def model_inference(args, config, ros_operator, save_episode=True): - global inference_lock - global inference_actions - global inference_timestep - global inference_thread - set_seed(1000) - - # 1 创建模型数据 继承nn.Module - policy = make_policy(config['policy_class'], config['policy_config']) - # print("model structure\n", policy.model) - - # 2 加载模型权重 - ckpt_path = os.path.join(config['ckpt_dir'], config['ckpt_name']) - state_dict = torch.load(ckpt_path) - new_state_dict = {} - for key, value in state_dict.items(): - if key in ["model.is_pad_head.weight", "model.is_pad_head.bias"]: - continue - if key in ["model.input_proj_next_action.weight", "model.input_proj_next_action.bias"]: - continue - new_state_dict[key] = value - loading_status = policy.deserialize(new_state_dict) - if not loading_status: - print("ckpt path not exist") - return False - - # 3 模型设置为cuda模式和验证模式 - policy.cuda() - policy.eval() - - # 4 加载统计值 - stats_path = os.path.join(config['ckpt_dir'], config['ckpt_stats_name']) - # 统计的数据 # 加载action_mean, action_std, qpos_mean, qpos_std 14维 - with open(stats_path, 'rb') as f: - stats = pickle.load(f) - - # 数据预处理和后处理函数定义 - pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std'] - post_process = lambda a: a * stats['action_std'] + stats['action_mean'] - - max_publish_step = config['episode_len'] - chunk_size = config['policy_config']['chunk_size'] - - # 发布基础的姿态 - left0 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, 3.557830810546875] - right0 = [-0.00133514404296875, 0.00438690185546875, 0.034523963928222656, -0.053597450256347656, -0.00476837158203125, -0.00209808349609375, 3.557830810546875] - left1 = [-0.00133514404296875, 0.00209808349609375, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3393220901489258] - right1 = [-0.00133514404296875, 0.00247955322265625, 0.01583099365234375, -0.032616615295410156, -0.00286102294921875, 0.00095367431640625, -0.3397035598754883] - - ros_operator.puppet_arm_publish_continuous(left0, right0) - input("Enter any key to continue :") - ros_operator.puppet_arm_publish_continuous(left1, right1) - action = None - # 推理 - with torch.inference_mode(): - while True and not rospy.is_shutdown(): - # 每个回合的步数 - t = 0 - max_t = 0 - rate = rospy.Rate(args.publish_rate) - if config['temporal_agg']: - all_time_actions = np.zeros([max_publish_step, max_publish_step + chunk_size, config['state_dim']]) - while t < max_publish_step and not rospy.is_shutdown(): - # start_time = time.time() - # query policy - if config['policy_class'] == "ACT": - if t >= max_t: - pre_action = action - inference_thread = threading.Thread(target=inference_process, - args=(args, config, ros_operator, - policy, stats, t, pre_action)) - inference_thread.start() - inference_thread.join() - inference_lock.acquire() - if inference_actions is not None: - inference_thread = None - all_actions = inference_actions - inference_actions = None - max_t = t + args.pos_lookahead_step - if config['temporal_agg']: - all_time_actions[[t], t:t + chunk_size] = all_actions - inference_lock.release() - if config['temporal_agg']: - actions_for_curr_step = all_time_actions[:, t] - actions_populated = np.all(actions_for_curr_step != 0, axis=1) - actions_for_curr_step = actions_for_curr_step[actions_populated] - k = 0.01 - exp_weights = np.exp(-k * np.arange(len(actions_for_curr_step))) - exp_weights = exp_weights / exp_weights.sum() - exp_weights = exp_weights[:, np.newaxis] - raw_action = (actions_for_curr_step * exp_weights).sum(axis=0, keepdims=True) - else: - if args.pos_lookahead_step != 0: - raw_action = all_actions[:, t % args.pos_lookahead_step] - else: - raw_action = all_actions[:, t % chunk_size] - else: - raise NotImplementedError - action = post_process(raw_action[0]) - left_action = action[:7] # 取7维度 - right_action = action[7:14] - ros_operator.puppet_arm_publish(left_action, right_action) # puppet_arm_publish_continuous_thread - if args.use_robot_base: - vel_action = action[14:16] - ros_operator.robot_base_publish(vel_action) - t += 1 - # end_time = time.time() - # print("publish: ", t) - # print("time:", end_time - start_time) - # print("left_action:", left_action) - # print("right_action:", right_action) - rate.sleep() - - -class RosOperator: - def __init__(self, args): - self.robot_base_deque = None - self.puppet_arm_right_deque = None - self.puppet_arm_left_deque = None - self.img_front_deque = None - self.img_right_deque = None - self.img_left_deque = None - self.img_front_depth_deque = None - self.img_right_depth_deque = None - self.img_left_depth_deque = None - self.bridge = None - self.puppet_arm_left_publisher = None - self.puppet_arm_right_publisher = None - self.robot_base_publisher = None - self.puppet_arm_publish_thread = None - self.puppet_arm_publish_lock = None - self.args = args - self.ctrl_state = False - self.ctrl_state_lock = threading.Lock() - self.init() - self.init_ros() - - def init(self): - self.bridge = CvBridge() - self.img_left_deque = deque() - self.img_right_deque = deque() - self.img_front_deque = deque() - self.img_left_depth_deque = deque() - self.img_right_depth_deque = deque() - self.img_front_depth_deque = deque() - self.puppet_arm_left_deque = deque() - self.puppet_arm_right_deque = deque() - self.robot_base_deque = deque() - self.puppet_arm_publish_lock = threading.Lock() - self.puppet_arm_publish_lock.acquire() - - def puppet_arm_publish(self, left, right): - joint_state_msg = JointState() - joint_state_msg.header = Header() - joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 - joint_state_msg.position = left - self.puppet_arm_left_publisher.publish(joint_state_msg) - joint_state_msg.position = right - self.puppet_arm_right_publisher.publish(joint_state_msg) - - def robot_base_publish(self, vel): - vel_msg = Twist() - vel_msg.linear.x = vel[0] - vel_msg.linear.y = 0 - vel_msg.linear.z = 0 - vel_msg.angular.x = 0 - vel_msg.angular.y = 0 - vel_msg.angular.z = vel[1] - self.robot_base_publisher.publish(vel_msg) - - def puppet_arm_publish_continuous(self, left, right): - rate = rospy.Rate(self.args.publish_rate) - left_arm = None - right_arm = None - while True and not rospy.is_shutdown(): - if len(self.puppet_arm_left_deque) != 0: - left_arm = list(self.puppet_arm_left_deque[-1].position) - if len(self.puppet_arm_right_deque) != 0: - right_arm = list(self.puppet_arm_right_deque[-1].position) - if left_arm is None or right_arm is None: - rate.sleep() - continue - else: - break - left_symbol = [1 if left[i] - left_arm[i] > 0 else -1 for i in range(len(left))] - right_symbol = [1 if right[i] - right_arm[i] > 0 else -1 for i in range(len(right))] - flag = True - step = 0 - while flag and not rospy.is_shutdown(): - if self.puppet_arm_publish_lock.acquire(False): - return - left_diff = [abs(left[i] - left_arm[i]) for i in range(len(left))] - right_diff = [abs(right[i] - right_arm[i]) for i in range(len(right))] - flag = False - for i in range(len(left)): - if left_diff[i] < self.args.arm_steps_length[i]: - left_arm[i] = left[i] - else: - left_arm[i] += left_symbol[i] * self.args.arm_steps_length[i] - flag = True - for i in range(len(right)): - if right_diff[i] < self.args.arm_steps_length[i]: - right_arm[i] = right[i] - else: - right_arm[i] += right_symbol[i] * self.args.arm_steps_length[i] - flag = True - joint_state_msg = JointState() - joint_state_msg.header = Header() - joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 - joint_state_msg.position = left_arm - self.puppet_arm_left_publisher.publish(joint_state_msg) - joint_state_msg.position = right_arm - self.puppet_arm_right_publisher.publish(joint_state_msg) - step += 1 - print("puppet_arm_publish_continuous:", step) - rate.sleep() - - def puppet_arm_publish_linear(self, left, right): - num_step = 100 - rate = rospy.Rate(200) - - left_arm = None - right_arm = None - - while True and not rospy.is_shutdown(): - if len(self.puppet_arm_left_deque) != 0: - left_arm = list(self.puppet_arm_left_deque[-1].position) - if len(self.puppet_arm_right_deque) != 0: - right_arm = list(self.puppet_arm_right_deque[-1].position) - if left_arm is None or right_arm is None: - rate.sleep() - continue - else: - break - - traj_left_list = np.linspace(left_arm, left, num_step) - traj_right_list = np.linspace(right_arm, right, num_step) - - for i in range(len(traj_left_list)): - traj_left = traj_left_list[i] - traj_right = traj_right_list[i] - traj_left[-1] = left[-1] - traj_right[-1] = right[-1] - joint_state_msg = JointState() - joint_state_msg.header = Header() - joint_state_msg.header.stamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', 'joint6'] # 设置关节名称 - joint_state_msg.position = traj_left - self.puppet_arm_left_publisher.publish(joint_state_msg) - joint_state_msg.position = traj_right - self.puppet_arm_right_publisher.publish(joint_state_msg) - rate.sleep() - - def puppet_arm_publish_continuous_thread(self, left, right): - if self.puppet_arm_publish_thread is not None: - self.puppet_arm_publish_lock.release() - self.puppet_arm_publish_thread.join() - self.puppet_arm_publish_lock.acquire(False) - self.puppet_arm_publish_thread = None - self.puppet_arm_publish_thread = threading.Thread(target=self.puppet_arm_publish_continuous, args=(left, right)) - self.puppet_arm_publish_thread.start() - - def get_frame(self): - if len(self.img_left_deque) == 0 or len(self.img_right_deque) == 0 or len(self.img_front_deque) == 0 or \ - (self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or len(self.img_right_depth_deque) == 0 or len(self.img_front_depth_deque) == 0)): - return False - if self.args.use_depth_image: - frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec(), - self.img_left_depth_deque[-1].header.stamp.to_sec(), self.img_right_depth_deque[-1].header.stamp.to_sec(), self.img_front_depth_deque[-1].header.stamp.to_sec()]) - else: - frame_time = min([self.img_left_deque[-1].header.stamp.to_sec(), self.img_right_deque[-1].header.stamp.to_sec(), self.img_front_deque[-1].header.stamp.to_sec()]) - - if len(self.img_left_deque) == 0 or self.img_left_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.img_right_deque) == 0 or self.img_right_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.img_front_deque) == 0 or self.img_front_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.puppet_arm_left_deque) == 0 or self.puppet_arm_left_deque[-1].header.stamp.to_sec() < frame_time: - return False - if len(self.puppet_arm_right_deque) == 0 or self.puppet_arm_right_deque[-1].header.stamp.to_sec() < frame_time: - return False - if self.args.use_depth_image and (len(self.img_left_depth_deque) == 0 or self.img_left_depth_deque[-1].header.stamp.to_sec() < frame_time): - return False - if self.args.use_depth_image and (len(self.img_right_depth_deque) == 0 or self.img_right_depth_deque[-1].header.stamp.to_sec() < frame_time): - return False - if self.args.use_depth_image and (len(self.img_front_depth_deque) == 0 or self.img_front_depth_deque[-1].header.stamp.to_sec() < frame_time): - return False - if self.args.use_robot_base and (len(self.robot_base_deque) == 0 or self.robot_base_deque[-1].header.stamp.to_sec() < frame_time): - return False - - while self.img_left_deque[0].header.stamp.to_sec() < frame_time: - self.img_left_deque.popleft() - img_left = self.bridge.imgmsg_to_cv2(self.img_left_deque.popleft(), 'passthrough') - - while self.img_right_deque[0].header.stamp.to_sec() < frame_time: - self.img_right_deque.popleft() - img_right = self.bridge.imgmsg_to_cv2(self.img_right_deque.popleft(), 'passthrough') - - while self.img_front_deque[0].header.stamp.to_sec() < frame_time: - self.img_front_deque.popleft() - img_front = self.bridge.imgmsg_to_cv2(self.img_front_deque.popleft(), 'passthrough') - - while self.puppet_arm_left_deque[0].header.stamp.to_sec() < frame_time: - self.puppet_arm_left_deque.popleft() - puppet_arm_left = self.puppet_arm_left_deque.popleft() - - while self.puppet_arm_right_deque[0].header.stamp.to_sec() < frame_time: - self.puppet_arm_right_deque.popleft() - puppet_arm_right = self.puppet_arm_right_deque.popleft() - - img_left_depth = None - if self.args.use_depth_image: - while self.img_left_depth_deque[0].header.stamp.to_sec() < frame_time: - self.img_left_depth_deque.popleft() - img_left_depth = self.bridge.imgmsg_to_cv2(self.img_left_depth_deque.popleft(), 'passthrough') - - img_right_depth = None - if self.args.use_depth_image: - while self.img_right_depth_deque[0].header.stamp.to_sec() < frame_time: - self.img_right_depth_deque.popleft() - img_right_depth = self.bridge.imgmsg_to_cv2(self.img_right_depth_deque.popleft(), 'passthrough') - - img_front_depth = None - if self.args.use_depth_image: - while self.img_front_depth_deque[0].header.stamp.to_sec() < frame_time: - self.img_front_depth_deque.popleft() - img_front_depth = self.bridge.imgmsg_to_cv2(self.img_front_depth_deque.popleft(), 'passthrough') - - robot_base = None - if self.args.use_robot_base: - while self.robot_base_deque[0].header.stamp.to_sec() < frame_time: - self.robot_base_deque.popleft() - robot_base = self.robot_base_deque.popleft() - - return (img_front, img_left, img_right, img_front_depth, img_left_depth, img_right_depth, - puppet_arm_left, puppet_arm_right, robot_base) - - def img_left_callback(self, msg): - if len(self.img_left_deque) >= 2000: - self.img_left_deque.popleft() - self.img_left_deque.append(msg) - - def img_right_callback(self, msg): - if len(self.img_right_deque) >= 2000: - self.img_right_deque.popleft() - self.img_right_deque.append(msg) - - def img_front_callback(self, msg): - if len(self.img_front_deque) >= 2000: - self.img_front_deque.popleft() - self.img_front_deque.append(msg) - - def img_left_depth_callback(self, msg): - if len(self.img_left_depth_deque) >= 2000: - self.img_left_depth_deque.popleft() - self.img_left_depth_deque.append(msg) - - def img_right_depth_callback(self, msg): - if len(self.img_right_depth_deque) >= 2000: - self.img_right_depth_deque.popleft() - self.img_right_depth_deque.append(msg) - - def img_front_depth_callback(self, msg): - if len(self.img_front_depth_deque) >= 2000: - self.img_front_depth_deque.popleft() - self.img_front_depth_deque.append(msg) - - def puppet_arm_left_callback(self, msg): - if len(self.puppet_arm_left_deque) >= 2000: - self.puppet_arm_left_deque.popleft() - self.puppet_arm_left_deque.append(msg) - - def puppet_arm_right_callback(self, msg): - if len(self.puppet_arm_right_deque) >= 2000: - self.puppet_arm_right_deque.popleft() - self.puppet_arm_right_deque.append(msg) - - def robot_base_callback(self, msg): - if len(self.robot_base_deque) >= 2000: - self.robot_base_deque.popleft() - self.robot_base_deque.append(msg) - - def ctrl_callback(self, msg): - self.ctrl_state_lock.acquire() - self.ctrl_state = msg.data - self.ctrl_state_lock.release() - - def get_ctrl_state(self): - self.ctrl_state_lock.acquire() - state = self.ctrl_state - self.ctrl_state_lock.release() - return state - - def init_ros(self): - rospy.init_node('joint_state_publisher', anonymous=True) - rospy.Subscriber(self.args.img_left_topic, Image, self.img_left_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_right_topic, Image, self.img_right_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_front_topic, Image, self.img_front_callback, queue_size=1000, tcp_nodelay=True) - if self.args.use_depth_image: - rospy.Subscriber(self.args.img_left_depth_topic, Image, self.img_left_depth_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_right_depth_topic, Image, self.img_right_depth_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.img_front_depth_topic, Image, self.img_front_depth_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.puppet_arm_left_topic, JointState, self.puppet_arm_left_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.puppet_arm_right_topic, JointState, self.puppet_arm_right_callback, queue_size=1000, tcp_nodelay=True) - rospy.Subscriber(self.args.robot_base_topic, Odometry, self.robot_base_callback, queue_size=1000, tcp_nodelay=True) - self.puppet_arm_left_publisher = rospy.Publisher(self.args.puppet_arm_left_cmd_topic, JointState, queue_size=10) - self.puppet_arm_right_publisher = rospy.Publisher(self.args.puppet_arm_right_cmd_topic, JointState, queue_size=10) - self.robot_base_publisher = rospy.Publisher(self.args.robot_base_cmd_topic, Twist, queue_size=10) - - -def get_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True) - parser.add_argument('--task_name', action='store', type=str, help='task_name', default='aloha_mobile_dummy', required=False) - parser.add_argument('--max_publish_step', action='store', type=int, help='max_publish_step', default=10000, required=False) - parser.add_argument('--ckpt_name', action='store', type=str, help='ckpt_name', default='policy_best.ckpt', required=False) - parser.add_argument('--ckpt_stats_name', action='store', type=str, help='ckpt_stats_name', default='dataset_stats.pkl', required=False) - parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', default='ACT', required=False) - parser.add_argument('--batch_size', action='store', type=int, help='batch_size', default=8, required=False) - parser.add_argument('--seed', action='store', type=int, help='seed', default=0, required=False) - parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', default=2000, required=False) - parser.add_argument('--lr', action='store', type=float, help='lr', default=1e-5, required=False) - parser.add_argument('--weight_decay', type=float, help='weight_decay', default=1e-4, required=False) - parser.add_argument('--dilation', action='store_true', - help="If true, we replace stride with dilation in the last convolutional block (DC5)", required=False) - parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), - help="Type of positional embedding to use on top of the image features", required=False) - parser.add_argument('--masks', action='store_true', - help="Train segmentation head if the flag is provided") - parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', default=10, required=False) - parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', default=512, required=False) - parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', default=3200, required=False) - parser.add_argument('--temporal_agg', action='store', type=bool, help='temporal_agg', default=True, required=False) - - parser.add_argument('--state_dim', action='store', type=int, help='state_dim', default=14, required=False) - parser.add_argument('--lr_backbone', action='store', type=float, help='lr_backbone', default=1e-5, required=False) - parser.add_argument('--backbone', action='store', type=str, help='backbone', default='resnet18', required=False) - parser.add_argument('--loss_function', action='store', type=str, help='loss_function l1 l2 l1+l2', default='l1', required=False) - parser.add_argument('--enc_layers', action='store', type=int, help='enc_layers', default=4, required=False) - parser.add_argument('--dec_layers', action='store', type=int, help='dec_layers', default=7, required=False) - parser.add_argument('--nheads', action='store', type=int, help='nheads', default=8, required=False) - parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer", required=False) - parser.add_argument('--pre_norm', action='store_true', required=False) - - parser.add_argument('--img_front_topic', action='store', type=str, help='img_front_topic', - default='/camera_f/color/image_raw', required=False) - parser.add_argument('--img_left_topic', action='store', type=str, help='img_left_topic', - default='/camera_l/color/image_raw', required=False) - parser.add_argument('--img_right_topic', action='store', type=str, help='img_right_topic', - default='/camera_r/color/image_raw', required=False) - - parser.add_argument('--img_front_depth_topic', action='store', type=str, help='img_front_depth_topic', - default='/camera_f/depth/image_raw', required=False) - parser.add_argument('--img_left_depth_topic', action='store', type=str, help='img_left_depth_topic', - default='/camera_l/depth/image_raw', required=False) - parser.add_argument('--img_right_depth_topic', action='store', type=str, help='img_right_depth_topic', - default='/camera_r/depth/image_raw', required=False) - - parser.add_argument('--puppet_arm_left_cmd_topic', action='store', type=str, help='puppet_arm_left_cmd_topic', - default='/master/joint_left', required=False) - parser.add_argument('--puppet_arm_right_cmd_topic', action='store', type=str, help='puppet_arm_right_cmd_topic', - default='/master/joint_right', required=False) - parser.add_argument('--puppet_arm_left_topic', action='store', type=str, help='puppet_arm_left_topic', - default='/puppet/joint_left', required=False) - parser.add_argument('--puppet_arm_right_topic', action='store', type=str, help='puppet_arm_right_topic', - default='/puppet/joint_right', required=False) - - parser.add_argument('--robot_base_topic', action='store', type=str, help='robot_base_topic', - default='/odom_raw', required=False) - parser.add_argument('--robot_base_cmd_topic', action='store', type=str, help='robot_base_topic', - default='/cmd_vel', required=False) - parser.add_argument('--use_robot_base', action='store', type=bool, help='use_robot_base', - default=False, required=False) - parser.add_argument('--publish_rate', action='store', type=int, help='publish_rate', - default=40, required=False) - parser.add_argument('--pos_lookahead_step', action='store', type=int, help='pos_lookahead_step', - default=0, required=False) - parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', - default=32, required=False) - parser.add_argument('--arm_steps_length', action='store', type=float, help='arm_steps_length', - default=[0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.2], required=False) - - parser.add_argument('--use_actions_interpolation', action='store', type=bool, help='use_actions_interpolation', - default=False, required=False) - parser.add_argument('--use_depth_image', action='store', type=bool, help='use_depth_image', - default=False, required=False) - - # for Diffusion - parser.add_argument('--observation_horizon', action='store', type=int, help='observation_horizon', default=1, required=False) - parser.add_argument('--action_horizon', action='store', type=int, help='action_horizon', default=8, required=False) - parser.add_argument('--num_inference_timesteps', action='store', type=int, help='num_inference_timesteps', default=10, required=False) - parser.add_argument('--ema_power', action='store', type=int, help='ema_power', default=0.75, required=False) - args = parser.parse_args() - return args - - -def main(): - args = get_arguments() - ros_operator = RosOperator(args) - config = get_model_config(args) - model_inference(args, config, ros_operator, save_episode=True) - - -if __name__ == '__main__': - main() -# python act/inference.py --ckpt_dir ~/train0314/ \ No newline at end of file diff --git a/lerobot_aloha/main.py b/lerobot_aloha/main.py new file mode 100644 index 0000000..0dc0290 --- /dev/null +++ b/lerobot_aloha/main.py @@ -0,0 +1,56 @@ +import argparse +from common.rosrobot_factory import RobotFactory +from common.utils.data_utils import record +from common.utils.replay_utils import replay + + +def get_arguments(): + """ + Parse command line arguments. + + Returns: + argparse.Namespace: Parsed arguments + """ + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.fps = 30 + args.resume = False + args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right" + args.root = "./data5" + args.episode = 0 # replay episode + args.num_image_writer_processes = 0 + args.num_image_writer_threads_per_camera = 4 + args.video = True + args.num_episodes = 100 + args.episode_time_s = 30000 + args.play_sounds = False + args.display_cameras = True + args.single_task = "move the bottle from the right to the scale right" + args.use_depth_image = False + args.use_base = False + args.push_to_hub = False + args.policy = None + args.control_type = "record" + return args + + +def control_robot(cfg): + """ + Control robot based on configuration. + + Args: + cfg: Configuration object + """ + # Create robot instance using factory pattern + robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=cfg) + + # Execute appropriate control mode + if cfg.control_type == "record": + record(robot, cfg) + elif cfg.control_type == "replay": + replay(robot, cfg) + + +if __name__ == "__main__": + cfg = get_arguments() + control_robot(cfg) diff --git a/lerobot_aloha/replay_data.py b/lerobot_aloha/replay_data.py index 6c880dc..8b29a2c 100644 --- a/lerobot_aloha/replay_data.py +++ b/lerobot_aloha/replay_data.py @@ -1,112 +1,36 @@ -#coding=utf-8 -import os -import numpy as np -import cv2 -import h5py +# coding=utf-8 import argparse import rospy - -from cv_bridge import CvBridge -from std_msgs.msg import Header -from sensor_msgs.msg import Image, JointState -from geometry_msgs.msg import Twist -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from common.rosrobot_factory import RobotFactory +from common.utils.replay_utils import replay - -def main(args): - rospy.init_node("replay_node") - bridge = CvBridge() - # img_left_publisher = rospy.Publisher(args.img_left_topic, Image, queue_size=10) - # img_right_publisher = rospy.Publisher(args.img_right_topic, Image, queue_size=10) - # img_front_publisher = rospy.Publisher(args.img_front_topic, Image, queue_size=10) +def get_arguments(): + """ + Parse command line arguments. - # puppet_arm_left_publisher = rospy.Publisher(args.puppet_arm_left_topic, JointState, queue_size=10) - # puppet_arm_right_publisher = rospy.Publisher(args.puppet_arm_right_topic, JointState, queue_size=10) - - master_arm_left_publisher = rospy.Publisher(args.master_arm_left_topic, JointState, queue_size=10) - master_arm_right_publisher = rospy.Publisher(args.master_arm_right_topic, JointState, queue_size=10) - - # robot_base_publisher = rospy.Publisher(args.robot_base_topic, Twist, queue_size=10) - - - # dataset_dir = args.dataset_dir - # episode_idx = args.episode_idx - # task_name = args.task_name - # dataset_name = f'episode_{episode_idx}' - - dataset = LeRobotDataset(args.repo_id, root=args.root, episodes=[args.episode]) - actions = dataset.hf_dataset.select_columns("action") - velocitys = dataset.hf_dataset.select_columns("observation.velocity") - efforts = dataset.hf_dataset.select_columns("observation.effort") - - origin_left = [-0.0057,-0.031, -0.0122, -0.032, 0.0099, 0.0179, 0.2279] - origin_right = [ 0.0616, 0.0021, 0.0475, -0.1013, 0.1097, 0.0872, 0.2279] - - joint_state_msg = JointState() - joint_state_msg.header = Header() - joint_state_msg.name = ['joint0', 'joint1', 'joint2', 'joint3', 'joint4', 'joint5', ''] # 设置关节名称 - twist_msg = Twist() - - rate = rospy.Rate(args.fps) - - # qposs, qvels, efforts, actions, base_actions, image_dicts = load_hdf5(os.path.join(dataset_dir, task_name), dataset_name) - - - last_action = [-0.00019073486328125, 0.00934600830078125, 0.01354217529296875, -0.01049041748046875, -0.00057220458984375, -0.00057220458984375, -0.00526118278503418, -0.00095367431640625, 0.00705718994140625, 0.01239776611328125, -0.00705718994140625, -0.00019073486328125, -0.00057220458984375, -0.009171326644718647] - last_velocity = [-0.010990142822265625, -0.010990142822265625, -0.03296661376953125, 0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] - last_effort = [-0.021978378295898438, 0.2417583465576172, 4.320878982543945, 3.6527481079101562, -0.013187408447265625, -0.013187408447265625, 0.0, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.010990142822265625, -0.010990142822265625, -0.03296661376953125, -0.03296661376953125] - rate = rospy.Rate(50) - for idx in range(len(actions)): - action = actions[idx]['action'].detach().cpu().numpy() - velocity = velocitys[idx]['observation.velocity'].detach().cpu().numpy() - effort = efforts[idx]['observation.effort'].detach().cpu().numpy() - if(rospy.is_shutdown()): - break - - new_actions = np.linspace(last_action, action, 5) # 插值 - new_velocitys = np.linspace(last_velocity, velocity, 5) # 插值 - new_efforts = np.linspace(last_effort, effort, 5) # 插值 - last_action = action - last_velocity = velocity - last_effort = effort - for act in new_actions: - print(np.round(act[:7], 4)) - cur_timestamp = rospy.Time.now() # 设置时间戳 - joint_state_msg.header.stamp = cur_timestamp - - joint_state_msg.position = act[:7] - joint_state_msg.velocity = last_velocity[:7] - joint_state_msg.effort = last_effort[:7] - master_arm_left_publisher.publish(joint_state_msg) - - joint_state_msg.position = act[7:] - joint_state_msg.velocity = last_velocity[:7] - joint_state_msg.effort = last_effort[7:] - master_arm_right_publisher.publish(joint_state_msg) - - if(rospy.is_shutdown()): - break - rate.sleep() - - - - -if __name__ == '__main__': + Returns: + argparse.Namespace: Parsed arguments + """ parser = argparse.ArgumentParser() - # parser.add_argument('--master_arm_left_topic', action='store', type=str, help='master_arm_left_topic', - # default='/master/joint_left', required=False) - # parser.add_argument('--master_arm_right_topic', action='store', type=str, help='master_arm_right_topic', - # default='/master/joint_right', required=False) - - args = parser.parse_args() args.repo_id = "tangger/test" args.root = "/home/ubuntu/LYT/aloha_lerobot/data1" - args.episode = 1 # replay episode - args.master_arm_left_topic = "/master/joint_left" - args.master_arm_right_topic = "/master/joint_right" + args.episode = 1 # replay episode args.fps = 30 + args.use_depth_image = False + args.use_base = False + return args - main(args) - # python collect_data.py --max_timesteps 500 --is_compress --episode_idx 0 \ No newline at end of file + +if __name__ == '__main__': + args = get_arguments() + + # Initialize ROS node + rospy.init_node("replay_node") + + # Create robot instance using factory pattern + robot = RobotFactory.create(config_file="/home/ubuntu/LYT/lerobot_aloha/lerobot_aloha/configs/agilex.yaml", args=args) + + # Replay the specified episode + replay(robot, args)