init repo

This commit is contained in:
2025-04-05 21:46:49 +08:00
parent 4b58b22868
commit 91c2b7b0cb
17 changed files with 2473 additions and 0 deletions

View File

@@ -0,0 +1,487 @@
import logging
import time
from dataclasses import asdict
from pprint import pformat
from pprint import pprint
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,
RecordControlConfig,
RemoteRobotConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.common.robot_devices.control_utils import (
# init_keyboard_listener,
record_episode,
stop_recording,
is_headless
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.common.utils.utils import get_safe_torch_device
from contextlib import nullcontext
from copy import copy
import torch
import rospy
import cv2
from lerobot.configs import parser
from agilex_robot import AgilexRobot
########################################################################################
# Control modes
########################################################################################
def predict_action(observation, policy, device, use_amp):
    """Run one inference step of `policy` on `observation` and return the action.

    Image entries (any key containing "image") are converted from uint8 HWC in
    [0, 255] to float32 CHW in [0, 1]; every entry gets a leading batch
    dimension and is moved to `device` before `policy.select_action` is called.
    The predicted action is returned unbatched, on CPU.
    """
    batch = copy(observation)
    amp_ctx = (
        torch.autocast(device_type=device.type)
        if device.type == "cuda" and use_amp
        else nullcontext()
    )
    with torch.inference_mode(), amp_ctx:
        for key in batch:
            tensor = batch[key]
            if "image" in key:
                # uint8 HWC -> float32 CHW in [0, 1] for the policy's vision stack.
                tensor = tensor.type(torch.float32) / 255
                tensor = tensor.permute(2, 0, 1).contiguous()
            # Add the batch dimension and move to the compute device.
            batch[key] = tensor.unsqueeze(0).to(device)
        # Compute the next action from the current observation.
        prediction = policy.select_action(batch)
        # Drop the batch dimension and bring the action back to CPU.
        action = prediction.squeeze(0).to("cpu")
    return action
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Main control loop: teleoperate or run a policy, optionally recording frames.

    Args:
        robot: robot interface providing `teleop_step`, `capture_observation`
            and `send_action`.
        control_time_s: maximum loop duration in seconds (None = run forever).
        teleoperate: if True, read observation/action pairs from teleoperation;
            otherwise capture observations and, if `policy` is given, act with it.
        display_cameras: show camera streams in tiled OpenCV windows
            (non-headless environments only).
        dataset: optional LeRobotDataset that each frame is appended to.
        events: shared dict of keyboard-driven flags (at least "exit_early").
        policy: optional policy used to predict actions from observations.
        fps: target loop frequency; also validated against `dataset.fps`.
        single_task: task label stored with each frame (required with `dataset`).

    Raises:
        ValueError: if a dataset is given without `single_task`, or its fps
            disagrees with the requested `fps`.
    """
    # TODO(rcadene): Add option to record logs
    if events is None:
        events = {"exit_early": False}
    if control_time_s is None:
        control_time_s = float("inf")
    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")
    if dataset is not None and fps is not None and dataset.fps != fps:
        # Bug fix: the original message formatted `dataset['fps']`, which raises
        # TypeError while building the error; use the `fps` attribute instead.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    # NOTE(review): rospy.Rate(None) would fail — assumes fps is set by callers; confirm.
    rate = rospy.Rate(fps)
    print_flag = True
    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                # Teleop streams not yet synchronized: warn once, wait one tick, retry.
                if print_flag:
                    print("sync data fail, retrying...\n")
                    print_flag = False
                rate.sleep()
                continue
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            # NOTE(review): with teleoperate=False and policy=None, `action` is
            # unbound here and this raises NameError — same as the original code.
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]
            # Tile camera windows on a grid. The screen is assumed to be
            # 1920 px wide and each window 640x480; adjust if needed.
            screen_width = 1920
            max_columns = int(screen_width / 640)
            window_width = 640
            window_height = 480
            for idx, key in enumerate(image_keys):
                if "depth" in key:
                    continue  # skip depth images
                # Convert RGB -> BGR, the channel order OpenCV displays.
                image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
                cv2.imshow(key, image)
                # Place the window at its grid slot.
                row = idx // max_columns
                col = idx % max_columns
                cv2.moveWindow(key, col * window_width, row * window_height)
                # Pump the OpenCV event loop for 1 ms so windows repaint.
                cv2.waitKey(1)

        if fps is not None:
            # Busy-wait out the remainder of the frame budget.
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        # log_control_info(robot, dt_s, fps=fps)
        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            # Consume the flag so the next episode starts clean.
            events["exit_early"] = False
            break
def init_keyboard_listener():
    """Start a keyboard listener that drives the recording event flags.

    Hotkeys: right arrow ends the current loop early, left arrow marks the
    episode for re-recording, ESC stops recording entirely, up arrow flags
    the start of recording. In a headless environment no listener is started
    and (None, events) is returned.

    Returns:
        (listener, events): the pynput listener (or None) and the shared flag dict.
    """
    events = {
        "exit_early": False,
        "record_start": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }
    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
                events["record_start"] = False
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and rerecord the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.up:
                print("Up arrow pressed. Start data recording...")
                events["record_start"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()
    return listener, events
def stop_recording(robot, listener, display_cameras):
    """Tear down the keyboard listener and camera windows after recording.

    No-op in a headless environment (there is nothing to stop or close).
    NOTE: this local definition shadows the `stop_recording` imported from
    `lerobot.common.robot_devices.control_utils` above.
    """
    if is_headless():
        return
    if listener is not None:
        listener.stop()
    if display_cameras:
        cv2.destroyAllWindows()
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    """Record one episode by running `control_loop` for at most `episode_time_s`.

    Teleoperation is enabled exactly when no policy is provided. This local
    definition shadows the `record_episode` imported from control_utils above.
    """
    control_loop(
        robot=robot,
        dataset=dataset,
        events=events,
        control_time_s=episode_time_s,
        display_cameras=display_cameras,
        policy=policy,
        fps=fps,
        single_task=single_task,
        teleoperate=policy is None,
    )
def record(
    robot,
    cfg
) -> LeRobotDataset:
    """Record `cfg.num_episodes` episodes with `robot` into a LeRobotDataset.

    When `cfg.resume` is set, an existing dataset at `cfg.root` is reopened and
    appended to; otherwise a new dataset is created from the robot's feature
    schema. The keyboard listener drives the flow (right arrow = end episode,
    left arrow = re-record it, ESC = stop the whole run). Optionally pushes the
    finished dataset to the hub.

    Returns:
        The populated LeRobotDataset.
    """
    # TODO(rcadene): Add option to record logs
    if cfg.resume:
        # Reopen the existing on-disk dataset and keep appending episodes.
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        if len(robot.cameras) > 0:
            # Spin up async image writers sized to the camera count.
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
            )
        # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
    else:
        # Create empty dataset or load existing saved episodes
        # sanity_check_dataset_name(cfg.repo_id, cfg.policy)
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
        )
    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
    # policy = None
    # if not robot.is_connected:
    #     robot.connect()
    listener, events = init_keyboard_listener()
    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
    # 2. give times to the robot devices to connect and start synchronizing,
    # 3. place the cameras windows on screen
    enable_teleoperation = policy is None
    log_say("Warmup record", cfg.play_sounds)
    print()
    # Operator instructions (Chinese): episode count, max frames per episode,
    # and the arrow-key / ESC controls for stop, start, re-record and exit.
    print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")
    # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
    # if has_method(robot, "teleop_safety_stop"):
    #     robot.teleop_safety_stop()
    recorded_episodes = 0
    while True:
        if recorded_episodes >= cfg.num_episodes:
            break
        # if events["record_start"]:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )
        # Execute a few seconds without recording to give time to manually reset the environment
        # Current code logic doesn't allow to teleoperate during this time.
        # TODO(rcadene): add an option to enable teleoperation during reset
        # Skip reset for the last episode to be recorded
        if not events["stop_recording"] and (
            (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")
            # reset_environment(robot, events, cfg.reset_time_s, cfg.fps)
        if events["rerecord_episode"]:
            # Discard the episode buffer and loop again without counting it.
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue
        dataset.save_episode()
        recorded_episodes += 1
        if events["stop_recording"]:
            break
    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)
    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
    log_say("Exiting", cfg.play_sounds)
    return dataset
def replay(
    robot: AgilexRobot,
    cfg,
):
    """Replay the actions of the recorded episode `cfg.episode` on the robot.

    Loads the episode from `cfg.repo_id` / `cfg.root`, then sends each stored
    action at the dataset's recording rate (`cfg.fps`).
    """
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs
    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")
    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    for frame_idx in range(dataset.num_frames):
        frame_start_t = time.perf_counter()
        robot.send_action(actions[frame_idx]["action"])
        # Pad out the step so actions are sent at the requested rate.
        busy_wait(1 / cfg.fps - (time.perf_counter() - frame_start_t))
        # log_control_info(robot, time.perf_counter() - frame_start_t, fps=cfg.fps)
import argparse
def get_arguments():
    """Build the hard-coded run configuration as an argparse Namespace.

    No CLI flags are declared: `parse_args` is only used to obtain an (empty)
    Namespace that all settings are then attached to. Note the local `parser`
    rename avoids shadowing the `parser` imported from `lerobot.configs`.
    """
    cli = argparse.ArgumentParser()
    args = cli.parse_args()
    settings = {
        "fps": 30,
        "resume": False,
        "repo_id": "move_the_bottle_from_the_right_to_the_scale_right",
        "root": "/home/ubuntu/LYT/aloha_lerobot/data4",
        "episode": 0,  # replay episode
        "num_image_writer_processes": 0,
        "num_image_writer_threads_per_camera": 4,
        "video": True,
        "num_episodes": 100,
        "episode_time_s": 30000,
        "play_sounds": False,
        "display_cameras": True,
        "single_task": "move the bottle from the right to the scale right",
        "use_depth_image": False,
        "use_base": False,
        "push_to_hub": False,
        "policy": None,
        # "teleoprate": True,
        "control_type": "record",  # or "replay"
    }
    for key, value in settings.items():
        setattr(args, key, value)
    return args
# @parser.wrap()
# def control_robot(cfg: ControlPipelineConfig):
# init_logging()
# logging.info(pformat(asdict(cfg)))
# # robot = make_robot_from_config(cfg.robot)
# from agilex_robot import AgilexRobot
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
# if isinstance(cfg.control, RecordControlConfig):
# print(cfg.control)
# record(robot, cfg.control)
# elif isinstance(cfg.control, ReplayControlConfig):
# replay(robot, cfg.control)
# # if robot.is_connected:
# # # Disconnect manually to avoid a "Core dump" during process
# # # termination due to camera threads not properly exiting.
# # robot.disconnect()
# @parser.wrap()
def control_robot(cfg):
    """Build the Agilex robot from its YAML config and run the requested mode.

    Dispatches on `cfg.control_type`: "record" collects episodes, "replay"
    replays a stored one; any other value is silently ignored, matching the
    original if/elif behavior.
    """
    # robot = make_robot_from_config(cfg.robot)
    from agilex_robot import AgilexRobot
    robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
    handlers = {"record": record, "replay": replay}
    handler = handlers.get(cfg.control_type)
    if handler is not None:
        handler(robot, cfg)
    # if robot.is_connected:
    #     # Disconnect manually to avoid a "Core dump" during process
    #     # termination due to camera threads not properly exiting.
    #     robot.disconnect()
# Script entry point: build the hard-coded config and run the chosen control mode.
if __name__ == "__main__":
    cfg = get_arguments()
    control_robot(cfg)
    # control_robot()
    # cfg = get_arguments()
    # from agilex_robot import AgilexRobot
    # robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
    # print(robot.features.items())
    # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
    # record(robot, cfg)
    # capture = robot.capture_observation()
    # import torch
    # torch.save(capture, "test.pt")
    # action = torch.tensor([[ 0.0277,  0.0167,  0.0142, -0.1628,  0.1473, -0.0296,  0.0238, -0.1094,
    #          0.0109,  0.0139, -0.1591, -0.1490, -0.1650, -0.0980]],
    #        device='cpu')
    # robot.send_action(action.squeeze(0))
    # print()