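"""Data collection and replay script for an AgilexRobot (ALOHA-style dual-arm setup) driven over ROS,
recording episodes into a LeRobotDataset.

The script teleoperates the robot (or runs a pretrained policy), records episodes, and can replay a
recorded episode. Keyboard keys control recording: up = start, right = end episode, left = re-record,
ESC = stop.
"""
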
import logging
import time
from dataclasses import asdict
from pprint import pformat, pprint

# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
    CalibrateControlConfig,
    ControlPipelineConfig,
    RecordControlConfig,
    RemoteRobotConfig,
    ReplayControlConfig,
    TeleoperateControlConfig,
)
from lerobot.common.robot_devices.control_utils import (
    # init_keyboard_listener,  # shadowed by the local definition below
    # record_episode,  # shadowed by the local definition below
    # stop_recording,  # shadowed by the local definition below
    is_headless,
)
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import get_safe_torch_device, has_method, init_logging, log_say
from contextlib import nullcontext
from copy import copy
import torch
import rospy
import cv2
from lerobot.configs import parser
from agilex_robot import AgilexRobot


########################################################################################
# Control modes
########################################################################################


def predict_action(observation, policy, device, use_amp):
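    """Run the policy on a single observation and return the predicted action as a CPU tensor.

    Images are converted to float32 in [0, 1], changed to channel-first layout, batched, and moved
    to `device`; inference runs under `torch.inference_mode()` (with autocast on CUDA when
    `use_amp` is set)."""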
    observation = copy(observation)
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
            if "image" in name:
                observation[name] = observation[name].type(torch.float32) / 255
                observation[name] = observation[name].permute(2, 0, 1).contiguous()
            observation[name] = observation[name].unsqueeze(0)
            observation[name] = observation[name].to(device)

        # Compute the next action with the policy
        # based on the current observation
        action = policy.select_action(observation)

        # Remove batch dimension
        action = action.squeeze(0)

        # Move to cpu, if not already the case
        action = action.to("cpu")

    return action


def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    fps: int | None = None,
    single_task: str | None = None,
):
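    """Run one control loop at `fps`: teleoperate or run a policy, optionally record frames into
    `dataset` and display camera streams, until `control_time_s` elapses, ROS shuts down, or
    `events["exit_early"]` is set."""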
    # TODO(rcadene): Add option to record logs
    # if not robot.is_connected:
    #     robot.connect()

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    rate = rospy.Rate(fps)
    print_flag = True
    while timestamp < control_time_s and not rospy.is_shutdown():
        # print(timestamp < control_time_s)
        # print(rospy.is_shutdown())
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step()
            if observation is None or action is None:
                if print_flag:
                    print("sync data failed, retrying...\n")
                    print_flag = False
                rate.sleep()
                continue
        else:
            # pass
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so action actually sent is saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        # if display_cameras and not is_headless():
        #     image_keys = [key for key in observation if "image" in key]
        #     for key in image_keys:
        #         if "depth" in key:
        #             pass
        #         else:
        #             cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
        #     # print(1)
        #     cv2.waitKey(1)

        if display_cameras and not is_headless():
            image_keys = [key for key in observation if "image" in key]

            # Screen resolution (assumed to be 1920x1080; adjust to the actual display)
            screen_width = 1920
            screen_height = 1080

            # Compute the window layout
            num_images = len(image_keys)
            max_columns = int(screen_width / 640)  # assume each window is 640 px wide
            rows = (num_images + max_columns - 1) // max_columns  # number of rows needed
            columns = min(num_images, max_columns)  # number of columns actually used

            # Display every image key (depth images are skipped)
            for idx, key in enumerate(image_keys):
                if "depth" in key:
                    continue  # skip depth images

                # Convert the image from RGB to BGR
                image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)

                # Show the image in its own window
                cv2.imshow(key, image)

                # Compute the window position
                window_width = 640
                window_height = 480
                row = idx // max_columns
                col = idx % max_columns
                x_position = col * window_width
                y_position = row * window_height

                # Move the window to its slot in the grid
                cv2.moveWindow(key, x_position, y_position)

            # Wait 1 ms so OpenCV can process GUI events
            cv2.waitKey(1)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        dt_s = time.perf_counter() - start_loop_t
        # log_control_info(robot, dt_s, fps=fps)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break


def init_keyboard_listener():
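    """Start a pynput keyboard listener and return `(listener, events)`.

    Keys: up arrow starts recording, right arrow ends the current episode early, left arrow
    re-records the last episode, ESC stops recording altogether. In a headless environment no
    listener is started and `listener` is None."""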
    # Allow to exit early while recording an episode or resetting the environment,
    # by tapping the right arrow key '->'. This might require a sudo permission
    # to allow your terminal to monitor keyboard events.
    events = {}
    events["exit_early"] = False
    events["record_start"] = False
    events["rerecord_episode"] = False
    events["stop_recording"] = False

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        listener = None
        return listener, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    def on_press(key):
        try:
            if key == keyboard.Key.right:
                print("Right arrow key pressed. Exiting loop...")
                events["exit_early"] = True
                events["record_start"] = False
            elif key == keyboard.Key.left:
                print("Left arrow key pressed. Exiting loop and re-recording the last episode...")
                events["rerecord_episode"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.esc:
                print("Escape key pressed. Stopping data recording...")
                events["stop_recording"] = True
                events["exit_early"] = True
            elif key == keyboard.Key.up:
                print("Up arrow pressed. Starting data recording...")
                events["record_start"] = True
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events


def stop_recording(robot, listener, display_cameras):
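    """Stop the keyboard listener and close any OpenCV display windows."""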
    if not is_headless():
        if listener is not None:
            listener.stop()

        if display_cameras:
            cv2.destroyAllWindows()


def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
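    """Record a single episode by running `control_loop` for `episode_time_s`, teleoperating when no
    policy is provided."""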
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        teleoperate=policy is None,
        single_task=single_task,
    )


def record(
    robot,
    cfg,
) -> LeRobotDataset:
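    """Record `cfg.num_episodes` episodes into a LeRobotDataset (resuming an existing one when
    `cfg.resume` is set), optionally running a pretrained policy, and return the dataset."""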
    # TODO(rcadene): Add option to record logs
    if cfg.resume:
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        if len(robot.cameras) > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
            )
        # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
    else:
        # Create empty dataset or load existing saved episodes
        # sanity_check_dataset_name(cfg.repo_id, cfg.policy)
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
        )

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
    # policy = None

    # if not robot.is_connected:
    #     robot.connect()

    listener, events = init_keyboard_listener()

    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
    # 2. give times to the robot devices to connect and start synchronizing,
    # 3. place the cameras windows on screen
    enable_teleoperation = policy is None
    log_say("Warmup record", cfg.play_sounds)
    print()
    print(
        f"Starting trajectory recording: {cfg.num_episodes} episodes will be recorded.\n"
        f"Each episode lasts at most {cfg.episode_time_s} seconds.\n"
        "Right arrow: finish recording the current episode.\n"
        "Up arrow: start recording the current episode.\n"
        "Left arrow: re-record the current episode.\n"
        "ESC: exit recording.\n"
    )
    # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)

    # if has_method(robot, "teleop_safety_stop"):
    #     robot.teleop_safety_stop()

    recorded_episodes = 0
    while True:
        if recorded_episodes >= cfg.num_episodes:
            break

        # if events["record_start"]:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes: {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Execute a few seconds without recording to give time to manually reset the environment
        # Current code logic doesn't allow to teleoperate during this time.
        # TODO(rcadene): add an option to enable teleoperation during reset
        # Skip reset for the last episode to be recorded
        if not events["stop_recording"] and (
            (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")
            # reset_environment(robot, events, cfg.reset_time_s, cfg.fps)

        if events["rerecord_episode"]:
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset


def replay(
    robot: AgilexRobot,
    cfg,
):
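    """Replay the recorded actions of episode `cfg.episode` from the dataset on the robot at `cfg.fps`."""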
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")

    # if not robot.is_connected:
    #     robot.connect()

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    for idx in range(dataset.num_frames):
        start_episode_t = time.perf_counter()

        action = actions[idx]["action"]
        robot.send_action(action)

        dt_s = time.perf_counter() - start_episode_t
        busy_wait(1 / cfg.fps - dt_s)

        dt_s = time.perf_counter() - start_episode_t
        # log_control_info(robot, dt_s, fps=cfg.fps)


import argparse


def get_arguments():
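    """Build the run configuration as hard-coded fields on an argparse namespace (fps, dataset paths,
    episode counts, task name, control type, ...)."""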
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.fps = 30
    args.resume = False
    args.repo_id = "move_the_bottle_from_the_right_to_the_scale_right"
    args.root = "/home/ubuntu/LYT/aloha_lerobot/data4"
    args.episode = 0  # episode index used by replay
    args.num_image_writer_processes = 0
    args.num_image_writer_threads_per_camera = 4
    args.video = True
    args.num_episodes = 100
    args.episode_time_s = 30000
    args.play_sounds = False
    args.display_cameras = True
    args.single_task = "move the bottle from the right to the scale right"
    args.use_depth_image = False
    args.use_base = False
    args.push_to_hub = False
    args.policy = None
    # args.teleoperate = True
    args.control_type = "record"
    # args.control_type = "replay"
    return args
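
# `args.policy` is left as None, so `record()` runs in teleoperation mode. To run a pretrained policy
# instead, `args.policy` would need to be a policy config object accepted by
# `make_policy(cfg.policy, ds_meta=dataset.meta)`; how that config is built depends on the installed
# lerobot version, so no loading call is assumed here.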


# @parser.wrap()
# def control_robot(cfg: ControlPipelineConfig):
#     init_logging()
#     logging.info(pformat(asdict(cfg)))
#
#     # robot = make_robot_from_config(cfg.robot)
#     from agilex_robot import AgilexRobot
#     robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
#
#     if isinstance(cfg.control, RecordControlConfig):
#         print(cfg.control)
#         record(robot, cfg.control)
#     elif isinstance(cfg.control, ReplayControlConfig):
#         replay(robot, cfg.control)
#
#     # if robot.is_connected:
#     #     # Disconnect manually to avoid a "Core dump" during process
#     #     # termination due to camera threads not properly exiting.
#     #     robot.disconnect()


# @parser.wrap()
def control_robot(cfg):
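    """Instantiate the AgilexRobot and dispatch to `record` or `replay` based on `cfg.control_type`."""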

    # robot = make_robot_from_config(cfg.robot)
    from agilex_robot import AgilexRobot
    robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)

    if cfg.control_type == "record":
        record(robot, cfg)
    elif cfg.control_type == "replay":
        replay(robot, cfg)

    # if robot.is_connected:
    #     # Disconnect manually to avoid a "Core dump" during process
    #     # termination due to camera threads not properly exiting.
    #     robot.disconnect()


if __name__ == "__main__":
    cfg = get_arguments()
    control_robot(cfg)

    # control_robot()
    # cfg = get_arguments()
    # from agilex_robot import AgilexRobot
    # robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
    # print(robot.features.items())
    # print([key for key, ft in robot.features.items() if ft["dtype"] == "video"])
    # record(robot, cfg)
    # capture = robot.capture_observation()
    # import torch
    # torch.save(capture, "test.pt")
    # action = torch.tensor([[ 0.0277,  0.0167,  0.0142, -0.1628,  0.1473, -0.0296,  0.0238, -0.1094,
    #                          0.0109,  0.0139, -0.1591, -0.1490, -0.1650, -0.0980]],
    #                       device='cpu')
    # robot.send_action(action.squeeze(0))
    # print()