Files
lerobot_aloha/collect_data/collect_data_lerobot.py
2025-04-05 21:46:49 +08:00

487 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import time
from dataclasses import asdict
from pprint import pformat
from pprint import pprint
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,
RecordControlConfig,
RemoteRobotConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.common.robot_devices.control_utils import (
# init_keyboard_listener,
record_episode,
stop_recording,
is_headless
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.common.utils.utils import get_safe_torch_device
from contextlib import nullcontext
from copy import copy
import torch
import rospy
import cv2
from lerobot.configs import parser
from agilex_robot import AgilexRobot
########################################################################################
# Control modes
########################################################################################
def predict_action(observation, policy, device, use_amp):
    """Run ``policy`` on a single robot observation and return the predicted action.

    The caller's mapping is not modified; a shallow copy is converted instead.
    Every entry whose key contains ``"image"`` is converted from an HxWxC
    tensor in [0, 255] to a CxHxW float32 tensor in [0, 1]. Every entry then
    gets a leading batch dimension and is moved to ``device``.

    Args:
        observation: Mapping of observation name -> torch tensor.
        policy: Object exposing ``select_action(batch)``.
        device: ``torch.device`` to run inference on.
        use_amp: Enable ``torch.autocast`` (only honoured on CUDA devices).

    Returns:
        The action tensor with its batch dimension removed, on the CPU.
    """
    batch = copy(observation)
    amp_on_cuda = device.type == "cuda" and use_amp
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if amp_on_cuda else nullcontext(),
    ):
        for key, value in batch.items():
            if "image" in key:
                # HWC integer pixels -> CHW float32 scaled to [0, 1].
                value = (value.type(torch.float32) / 255).permute(2, 0, 1).contiguous()
            batch[key] = value.unsqueeze(0).to(device)
        # The policy predicts the next action from the current observation.
        predicted = policy.select_action(batch)
        return predicted.squeeze(0).to("cpu")
def _show_camera_windows(observation, screen_width=1920, window_width=640, window_height=480):
    """Show every non-depth image observation in its own OpenCV window, tiled in a grid.

    ``screen_width`` is an assumed display width (1920 px by default; adjust for
    the actual screen). Windows are placed row by row on cells of
    ``window_width`` x ``window_height`` pixels.
    """
    image_keys = [key for key in observation if "image" in key and "depth" not in key]
    max_columns = max(1, screen_width // window_width)
    # Grid slots are counted over displayed images only, so skipped depth
    # streams do not leave holes in the layout.
    for slot, key in enumerate(image_keys):
        # Convert RGB to BGR, the channel order cv2.imshow expects.
        image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
        cv2.imshow(key, image)
        row, col = divmod(slot, max_columns)
        cv2.moveWindow(key, col * window_width, row * window_height)
    # Let HighGUI process window events once per frame.
    cv2.waitKey(1)


def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Drive the robot in a loop, optionally recording every frame into ``dataset``.

    Each iteration works as follows:
    - Obtain an observation and action. With ``teleoperate``, this is
      ``robot.teleop_step()``. Otherwise it is ``robot.capture_observation()``,
      plus ``policy`` to predict the action that is sent with
      ``robot.send_action``.
    - Append the frame to ``dataset``.
    - Optionally display the camera images.
    - Wait so that iterations run at ``fps``.

    The loop ends after ``control_time_s`` seconds, when ROS shuts down, or when
    ``events["exit_early"]`` is set (the flag is cleared before returning).

    Args:
        robot: Robot exposing ``teleop_step``, ``capture_observation`` and ``send_action``.
        control_time_s: Maximum duration in seconds; ``None`` means no limit.
        teleoperate: Take actions from teleoperation instead of a policy.
        display_cameras: Show non-depth ``*image*`` observations in OpenCV windows.
        dataset: When given, each frame is stored with ``dataset.add_frame``.
        events: Flag dict shared with the keyboard listener; ``"exit_early"`` is read here.
        policy: Policy used when ``teleoperate`` is false.
        fps: Target loop frequency; ``None`` disables pacing.
        single_task: Task description stored with every recorded frame.

    Raises:
        ValueError: if ``dataset`` is given without ``single_task``, if the dataset
            fps differs from ``fps``, or if frames must be recorded but there is no
            action source (neither ``teleoperate`` nor ``policy``).
    """
    if events is None:
        events = {"exit_early": False}
    if control_time_s is None:
        control_time_s = float("inf")
    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")
    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
    if dataset is not None and not teleoperate and policy is None:
        # Without an action source no `action` exists to store with the frame.
        raise ValueError("Recording requires an action source: set `teleoperate` or provide a `policy`.")

    # rospy.Rate needs a numeric frequency; without fps, sync retries are not paced.
    retry_rate = rospy.Rate(fps) if fps is not None else None
    sync_failure_reported = False
    timestamp = 0
    start_episode_t = time.perf_counter()

    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                # Data streams are not synchronised yet: report once, then retry.
                if not sync_failure_reported:
                    print("sync data fail, retrying...\n")
                    sync_failure_reported = True
                if retry_rate is not None:
                    retry_rate.sleep()
                # Keep the time budget effective while retrying.
                timestamp = time.perf_counter() - start_episode_t
                continue
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # The action can be clipped by the robot (e.g. `max_relative_target`),
                # so store what was actually sent rather than the raw prediction.
                action = {"action": robot.send_action(pred_action)}

        if dataset is not None:
            dataset.add_frame({**observation, **action, "task": single_task})

        if display_cameras and not is_headless():
            _show_camera_windows(observation)

        if fps is not None:
            busy_wait(1 / fps - (time.perf_counter() - start_loop_t))

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
def init_keyboard_listener():
    """Start a keyboard listener that controls the recording session.

    Pressing the right arrow key ends the current episode early. Tapping keys
    may require extra permissions for the terminal to monitor keyboard events.

    Key bindings (flags set in ``events``):
        right -> ``exit_early`` (and clears ``record_start``)
        left  -> ``rerecord_episode`` and ``exit_early``
        esc   -> ``stop_recording`` and ``exit_early``
        up    -> ``record_start``

    Returns:
        ``(listener, events)``. ``listener`` is the started ``pynput`` listener,
        or ``None`` in a headless environment. ``events`` is the dict of flags.
    """
    events = {
        "exit_early": False,
        "record_start": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # pynput is imported only when not running headless.
    from pynput import keyboard

    # (key, message printed on press, flags to set) -- checked in this order.
    bindings = (
        (
            keyboard.Key.right,
            "Right arrow key pressed. Exiting loop...",
            {"exit_early": True, "record_start": False},
        ),
        (
            keyboard.Key.left,
            "Left arrow key pressed. Exiting loop and rerecord the last episode...",
            {"rerecord_episode": True, "exit_early": True},
        ),
        (
            keyboard.Key.esc,
            "Escape key pressed. Stopping data recording...",
            {"stop_recording": True, "exit_early": True},
        ),
        (
            keyboard.Key.up,
            "Up arrow pressed. Start data recording...",
            {"record_start": True},
        ),
    )

    def on_press(key):
        try:
            for bound_key, message, flags in bindings:
                if key == bound_key:
                    print(message)
                    events.update(flags)
                    break
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events
def stop_recording(robot, listener, display_cameras):
    """Release the interactive resources used while recording.

    In a non-headless environment this stops the keyboard ``listener`` (when one
    exists) and closes all OpenCV windows if ``display_cameras`` is set. In a
    headless environment it does nothing. ``robot`` is accepted for interface
    compatibility and is not used.
    """
    if is_headless():
        return
    if listener is not None:
        listener.stop()
    if display_cameras:
        cv2.destroyAllWindows()
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    """Record one episode into ``dataset`` via ``control_loop``.

    The robot is teleoperated when ``policy`` is ``None``; otherwise the policy
    drives it. The episode lasts at most ``episode_time_s`` seconds and ends
    early when ``events["exit_early"]`` is set or ROS shuts down.
    """
    use_teleoperation = policy is None
    control_loop(
        robot,
        control_time_s=episode_time_s,
        teleoperate=use_teleoperation,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        single_task=single_task,
    )
def _open_dataset(robot, cfg):
    """Return the dataset to record into: reload it when ``cfg.resume`` is set, otherwise create it."""
    image_writer_threads = cfg.num_image_writer_threads_per_camera * len(robot.cameras)
    if cfg.resume:
        dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
        if len(robot.cameras) > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=image_writer_threads,
            )
        return dataset
    return LeRobotDataset.create(
        cfg.repo_id,
        cfg.fps,
        root=cfg.root,
        robot=None,
        features=robot.features,
        use_videos=cfg.video,
        image_writer_processes=cfg.num_image_writer_processes,
        image_writer_threads=image_writer_threads,
    )


def _episode_buffer_is_empty(dataset):
    """Return True only when the dataset's pending episode provably contains no frames."""
    # NOTE(review): relies on LeRobotDataset keeping `episode_buffer` as a dict
    # with a "size" entry (or None before the first add_frame) -- confirm for
    # the installed lerobot version. If that shape is not found we report
    # "not empty" so that save_episode() is still attempted and no data is lost.
    if not hasattr(dataset, "episode_buffer"):
        return False
    buffer = dataset.episode_buffer
    if buffer is None:
        return True
    return isinstance(buffer, dict) and buffer.get("size") == 0


def record(robot, cfg) -> LeRobotDataset:
    """Record up to ``cfg.num_episodes`` episodes into a LeRobot dataset.

    Keyboard controls come from ``init_keyboard_listener``:
    - Right arrow ends the current episode.
    - Left arrow discards it and records it again.
    - Esc stops the session.

    Recording also stops when ROS shuts down. The keyboard listener and camera
    windows are released even if recording raises. Episodes that end before
    any frame was captured are skipped rather than saved.

    Args:
        robot: Robot providing ``cameras`` and ``features``.
        cfg: Configuration namespace (``repo_id``, ``root``, ``fps``, ``num_episodes``, ...).

    Returns:
        The dataset the episodes were saved into.
    """
    dataset = _open_dataset(robot, cfg)
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
    listener, events = init_keyboard_listener()

    # No warmup phase is run before the first episode.
    log_say("Warmup record", cfg.play_sounds)
    print()
    # episode_time_s bounds control_loop's wall-clock timer, so it is in seconds.
    print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}秒\n按右方向键代表当前轨迹结束录制\n按上方向键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC键代表退出轨迹录制\n")

    recorded_episodes = 0
    try:
        while (
            recorded_episodes < cfg.num_episodes
            and not events["stop_recording"]
            and not rospy.is_shutdown()
        ):
            log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
            pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
            record_episode(
                robot=robot,
                dataset=dataset,
                events=events,
                episode_time_s=cfg.episode_time_s,
                display_cameras=cfg.display_cameras,
                policy=policy,
                fps=cfg.fps,
                single_task=cfg.single_task,
            )

            # Only a prompt is emitted here; no automatic reset period is run.
            # It is skipped after the last episode unless that episode is re-recorded.
            if not events["stop_recording"] and (
                recorded_episodes < cfg.num_episodes - 1 or events["rerecord_episode"]
            ):
                log_say("Reset the environment", cfg.play_sounds)
                pprint("Reset the environment, stop recording")

            if events["rerecord_episode"]:
                log_say("Re-record episode", cfg.play_sounds)
                pprint("Re-record episode")
                events["rerecord_episode"] = False
                events["exit_early"] = False
                dataset.clear_episode_buffer()
                continue

            if _episode_buffer_is_empty(dataset):
                # An episode can end before its first synchronised frame (e.g. the
                # right arrow was pressed immediately); there is nothing to save.
                logging.warning("Episode ended before any frame was captured; nothing to save.")
                continue

            dataset.save_episode()
            recorded_episodes += 1
    finally:
        log_say("Stop recording", cfg.play_sounds, blocking=True)
        stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
    log_say("Exiting", cfg.play_sounds)
    return dataset
def replay(
    robot: AgilexRobot,
    cfg,
):
    """Send every recorded action of episode ``cfg.episode`` to ``robot`` at ``cfg.fps``.

    Actions are read from the ``action`` column of the dataset at
    ``cfg.root``/``cfg.repo_id``. Replay stops early if ROS shuts down.
    """
    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    period_s = 1 / cfg.fps
    for idx in range(dataset.num_frames):
        if rospy.is_shutdown():
            break
        start_frame_t = time.perf_counter()
        robot.send_action(actions[idx]["action"])
        # Sleep for the remainder of the frame period to hold the target rate.
        busy_wait(period_s - (time.perf_counter() - start_frame_t))
import argparse
def get_arguments(argv=None):
    """Build the configuration for recording or replaying episodes.

    Every option defaults to the value this script previously hard-coded, so
    running without CLI arguments behaves as before. Each value can now be
    overridden from the command line (e.g. ``--control-type replay --episode 3``).

    Args:
        argv: Argument list to parse; ``None`` parses ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the fields read by ``control_robot``,
        ``record`` and ``replay``. ``policy`` is always ``None`` because a
        policy config object cannot be given on the command line.
    """
    bool_flag = argparse.BooleanOptionalAction
    arg_parser = argparse.ArgumentParser(description="Record or replay LeRobot episodes on an Agilex robot.")
    arg_parser.add_argument("--fps", type=int, default=30)
    arg_parser.add_argument("--resume", action=bool_flag, default=False)
    arg_parser.add_argument("--repo-id", default="move_the_bottle_from_the_right_to_the_scale_right")
    arg_parser.add_argument("--root", default="/home/ubuntu/LYT/aloha_lerobot/data4")
    arg_parser.add_argument("--episode", type=int, default=0, help="Episode index to replay.")
    arg_parser.add_argument("--num-image-writer-processes", type=int, default=0)
    arg_parser.add_argument("--num-image-writer-threads-per-camera", type=int, default=4)
    arg_parser.add_argument("--video", action=bool_flag, default=True)
    arg_parser.add_argument("--num-episodes", type=int, default=100)
    arg_parser.add_argument(
        "--episode-time-s", type=float, default=30000, help="Maximum episode duration in seconds."
    )
    arg_parser.add_argument("--play-sounds", action=bool_flag, default=False)
    arg_parser.add_argument("--display-cameras", action=bool_flag, default=True)
    arg_parser.add_argument("--single-task", default="move the bottle from the right to the scale right")
    arg_parser.add_argument("--use-depth-image", action=bool_flag, default=False)
    arg_parser.add_argument("--use-base", action=bool_flag, default=False)
    arg_parser.add_argument("--push-to-hub", action=bool_flag, default=False)
    # `record` reads these two when pushing to the hub.
    arg_parser.add_argument("--tags", nargs="*", default=None)
    arg_parser.add_argument("--private", action=bool_flag, default=False)
    arg_parser.add_argument("--control-type", choices=("record", "replay"), default="record")

    args = arg_parser.parse_args(argv)
    args.policy = None
    return args
# @parser.wrap()
# def control_robot(cfg: ControlPipelineConfig):
# init_logging()
# logging.info(pformat(asdict(cfg)))
# # robot = make_robot_from_config(cfg.robot)
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# if isinstance(cfg.control, RecordControlConfig):
# print(cfg.control)
# record(robot, cfg.control)
# elif isinstance(cfg.control, ReplayControlConfig):
# replay(robot, cfg.control)
# # if robot.is_connected:
# # # Disconnect manually to avoid a "Core dump" during process
# # # termination due to camera threads not properly exiting.
# # robot.disconnect()
# @parser.wrap()
def control_robot(cfg, config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml"):
    """Create the Agilex robot and run the mode selected by ``cfg.control_type``.

    Args:
        cfg: Configuration namespace. ``cfg.control_type`` must be ``"record"``
            or ``"replay"``. The whole namespace is forwarded to ``AgilexRobot``
            and to the selected mode.
        config_file: Path of the robot YAML configuration given to ``AgilexRobot``.

    Raises:
        ValueError: if ``cfg.control_type`` is not a supported mode. This is
            checked before the robot is created, so no hardware is touched.
    """
    supported_modes = ("record", "replay")
    if cfg.control_type not in supported_modes:
        raise ValueError(
            f"Unknown control_type {cfg.control_type!r}; expected one of {supported_modes}."
        )

    robot = AgilexRobot(config_file=config_file, args=cfg)
    if cfg.control_type == "record":
        record(robot, cfg)
    else:
        replay(robot, cfg)
if __name__ == "__main__":
    # Build the configuration and run the selected control mode (record or replay).
    control_robot(get_arguments())