init repo
This commit is contained in:
487
collect_data/collect_data_lerobot.py
Normal file
487
collect_data/collect_data_lerobot.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict
|
||||
from pprint import pformat
|
||||
from pprint import pprint
|
||||
|
||||
# from safetensors.torch import load_file, save_file
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.robot_devices.control_configs import (
|
||||
CalibrateControlConfig,
|
||||
ControlPipelineConfig,
|
||||
RecordControlConfig,
|
||||
RemoteRobotConfig,
|
||||
ReplayControlConfig,
|
||||
TeleoperateControlConfig,
|
||||
)
|
||||
from lerobot.common.robot_devices.control_utils import (
|
||||
# init_keyboard_listener,
|
||||
record_episode,
|
||||
stop_recording,
|
||||
is_headless
|
||||
)
|
||||
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
|
||||
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
|
||||
from lerobot.common.utils.utils import has_method, init_logging, log_say
|
||||
from lerobot.common.utils.utils import get_safe_torch_device
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
import torch
|
||||
import rospy
|
||||
import cv2
|
||||
from lerobot.configs import parser
|
||||
from agilex_robot import AgilexRobot
|
||||
|
||||
|
||||
########################################################################################
|
||||
# Control modes
|
||||
########################################################################################
|
||||
|
||||
|
||||
def predict_action(observation, policy, device, use_amp):
    """Run `policy` on a single observation and return the action on CPU.

    Entries whose key contains "image" are converted from HWC uint8 to a
    CHW float32 tensor in [0, 1]; every entry is batched (unsqueezed) and
    moved to `device` before the policy is queried.
    """
    observation = copy(observation)

    autocast_ctx = (
        torch.autocast(device_type=device.type)
        if device.type == "cuda" and use_amp
        else nullcontext()
    )
    with torch.inference_mode(), autocast_ctx:
        # Convert to pytorch format: channel first, float32 in [0, 1],
        # with a leading batch dimension.
        for key in observation:
            if "image" in key:
                img = observation[key].type(torch.float32) / 255
                observation[key] = img.permute(2, 0, 1).contiguous()
            observation[key] = observation[key].unsqueeze(0).to(device)

        # Query the policy for the next action given the current observation.
        action = policy.select_action(observation)

    # Drop the batch dimension and make sure the result lives on the CPU.
    return action.squeeze(0).to("cpu")
|
||||
|
||||
def control_loop(
    robot,
    control_time_s=None,
    teleoperate=False,
    display_cameras=False,
    dataset: LeRobotDataset | None = None,
    events=None,
    policy=None,
    fps: int | None = None,
    single_task: str | None = None,
):
    """Main acquisition loop: teleoperate or run a policy, optionally recording.

    Each iteration grabs one (observation, action) pair — from the master arms
    when `teleoperate` is True, otherwise from `robot.capture_observation()`
    plus a policy inference — appends it to `dataset` when recording, shows
    the camera feeds, and paces itself to `fps`.

    Args:
        robot: robot interface providing `teleop_step` / `capture_observation`
            / `send_action`.
        control_time_s: wall-clock duration of the loop; None means unbounded.
        teleoperate: read actions from the teleoperation master arms.
        display_cameras: tile camera images on screen (no-op when headless).
        dataset: LeRobotDataset to append frames to, or None.
        events: shared dict of keyboard flags; created with only "exit_early"
            when None.
        policy: optional policy used when not teleoperating.
        fps: target loop frequency; also validated against `dataset.fps`.
        single_task: task label stored with every recorded frame.

    Raises:
        ValueError: when recording without `single_task`, or when `fps`
            disagrees with `dataset.fps`.
    """
    # TODO(rcadene): Add option to record logs

    if events is None:
        events = {"exit_early": False}

    if control_time_s is None:
        control_time_s = float("inf")

    if dataset is not None and single_task is None:
        raise ValueError("You need to provide a task as argument in `single_task`.")

    if dataset is not None and fps is not None and dataset.fps != fps:
        # Bug fix: the message previously evaluated `dataset['fps']`, which
        # raises TypeError on a LeRobotDataset; use the attribute instead.
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    timestamp = 0
    start_episode_t = time.perf_counter()
    # NOTE(review): rospy.Rate(None) would fail — callers appear to always
    # pass fps; confirm before relying on fps=None here.
    rate = rospy.Rate(fps)
    print_flag = True
    while timestamp < control_time_s and not rospy.is_shutdown():
        start_loop_t = time.perf_counter()

        if teleoperate:
            observation, action = robot.teleop_step()
            # The driver returns (None, None) until the master/puppet streams
            # are synchronized; warn once, then silently retry at loop rate.
            if observation is None or action is None:
                if print_flag:
                    print("sync data fail, retrying...\n")
                    print_flag = False
                rate.sleep()
                continue
        else:
            observation = robot.capture_observation()
            if policy is not None:
                pred_action = predict_action(
                    observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
                )
                # Action can eventually be clipped using `max_relative_target`,
                # so the action actually sent is the one saved in the dataset.
                action = robot.send_action(pred_action)
                action = {"action": action}
            # NOTE(review): with teleoperate=False and policy=None, `action`
            # is never assigned and recording below would raise NameError;
            # current callers always set teleoperate = (policy is None).

        if dataset is not None:
            frame = {**observation, **action, "task": single_task}
            dataset.add_frame(frame)

        if display_cameras and not is_headless():
            _display_observation_windows(observation)

        if fps is not None:
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break


def _display_observation_windows(observation):
    """Tile every non-depth camera image of `observation` on screen."""
    image_keys = [key for key in observation if "image" in key]

    # Assumed screen width (px); adjust to the actual display if needed.
    screen_width = 1920
    window_width = 640
    window_height = 480
    max_columns = int(screen_width / window_width)  # windows per row

    for idx, key in enumerate(image_keys):
        if "depth" in key:
            continue  # skip depth images

        # Convert from RGB (robot convention) to BGR (OpenCV convention).
        image = cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
        cv2.imshow(key, image)

        # Place the window on a grid so feeds do not overlap.
        row = idx // max_columns
        col = idx % max_columns
        cv2.moveWindow(key, col * window_width, row * window_height)

        # Pump the GUI event queue for ~1 ms.
        cv2.waitKey(1)
|
||||
|
||||
|
||||
def init_keyboard_listener():
    """Start a pynput keyboard listener driving the recording event flags.

    Returns `(listener, events)` where `events` is a dict of booleans shared
    with the control loop. In a headless environment no listener is started
    and `listener` is None.
    """
    # Allow to exit early while recording an episode or resetting the
    # environment by tapping the right arrow key '->'. This might require a
    # sudo permission to allow your terminal to monitor keyboard events.
    events = {
        "exit_early": False,
        "record_start": False,
        "rerecord_episode": False,
        "stop_recording": False,
    }

    if is_headless():
        logging.warning(
            "Headless environment detected. On-screen cameras display and keyboard inputs will not be available."
        )
        return None, events

    # Only import pynput if not in a headless environment
    from pynput import keyboard

    # key -> (message to print, event flags to update)
    key_actions = {
        keyboard.Key.right: (
            "Right arrow key pressed. Exiting loop...",
            {"exit_early": True, "record_start": False},
        ),
        keyboard.Key.left: (
            "Left arrow key pressed. Exiting loop and rerecord the last episode...",
            {"rerecord_episode": True, "exit_early": True},
        ),
        keyboard.Key.esc: (
            "Escape key pressed. Stopping data recording...",
            {"stop_recording": True, "exit_early": True},
        ),
        keyboard.Key.up: (
            "Up arrow pressed. Start data recording...",
            {"record_start": True},
        ),
    }

    def on_press(key):
        try:
            if key in key_actions:
                message, flag_updates = key_actions[key]
                print(message)
                events.update(flag_updates)
        except Exception as e:
            print(f"Error handling key press: {e}")

    listener = keyboard.Listener(on_press=on_press)
    listener.start()

    return listener, events
|
||||
|
||||
|
||||
def stop_recording(robot, listener, display_cameras):
    """Tear down the recording UI: stop the keyboard listener and close windows.

    No-op in a headless environment. `robot` is accepted for interface
    compatibility but not used here.
    """
    if is_headless():
        return

    if listener is not None:
        listener.stop()

    if display_cameras:
        cv2.destroyAllWindows()
|
||||
|
||||
|
||||
def record_episode(
    robot,
    dataset,
    events,
    episode_time_s,
    display_cameras,
    policy,
    fps,
    single_task,
):
    """Record one episode by running the control loop for `episode_time_s`.

    Teleoperation is enabled exactly when no policy is provided; all other
    arguments are forwarded to `control_loop` unchanged.
    """
    control_loop(
        robot=robot,
        control_time_s=episode_time_s,
        teleoperate=policy is None,
        display_cameras=display_cameras,
        dataset=dataset,
        events=events,
        policy=policy,
        fps=fps,
        single_task=single_task,
    )
|
||||
|
||||
|
||||
def record(
    robot,
    cfg
) -> LeRobotDataset:
    """Record `cfg.num_episodes` episodes into a LeRobotDataset and return it.

    Resumes an existing dataset when `cfg.resume` is set, otherwise creates a
    fresh one. Keyboard events (see `init_keyboard_listener`) control early
    exit, re-recording, and stopping; optionally pushes the result to the hub.

    Args:
        robot: robot interface exposing `cameras`, `features`, teleop/capture.
        cfg: configuration namespace (see `get_arguments` for the fields read).

    Returns:
        The populated LeRobotDataset.
    """
    # TODO(rcadene): Add option to record logs
    if cfg.resume:
        # Resume: reopen the existing dataset on disk and restart the
        # asynchronous image writer for the attached cameras.
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        if len(robot.cameras) > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                # one pool of threads per camera
                num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
            )
        # sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
    else:
        # Create empty dataset or load existing saved episodes
        # sanity_check_dataset_name(cfg.repo_id, cfg.policy)
        dataset = LeRobotDataset.create(
            cfg.repo_id,
            cfg.fps,
            root=cfg.root,
            # features are supplied explicitly instead of derived from a robot config
            robot=None,
            features=robot.features,
            use_videos=cfg.video,
            image_writer_processes=cfg.num_image_writer_processes,
            image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
        )

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
    # policy = None

    # if not robot.is_connected:
    #     robot.connect()

    # Keyboard flags shared with the control loop (exit/rerecord/stop/start).
    listener, events = init_keyboard_listener()

    # Execute a few seconds without recording to:
    # 1. teleoperate the robot to move it in starting position if no policy provided,
    # 2. give times to the robot devices to connect and start synchronizing,
    # 3. place the cameras windows on screen
    enable_teleoperation = policy is None
    log_say("Warmup record", cfg.play_sounds)
    print()
    # Operator instructions (Chinese): number of episodes, max episode length,
    # and the arrow-key / ESC bindings for stop, start, re-record, and quit.
    print(f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n每条轨迹的最长时间为{cfg.episode_time_s}frame\n按右方向键代表当前轨迹结束录制\n按上方面键代表当前轨迹开始录制\n按左方向键代表当前轨迹重新录制\n按ESC方向键代表退出轨迹录制\n")
    # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)

    # if has_method(robot, "teleop_safety_stop"):
    #     robot.teleop_safety_stop()

    recorded_episodes = 0
    while True:
        if recorded_episodes >= cfg.num_episodes:
            break

        # if events["record_start"]:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        pprint(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Execute a few seconds without recording to give time to manually reset the environment
        # Current code logic doesn't allow to teleoperate during this time.
        # TODO(rcadene): add an option to enable teleoperation during reset
        # Skip reset for the last episode to be recorded
        if not events["stop_recording"] and (
            (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", cfg.play_sounds)
            pprint("Reset the environment, stop recording")
            # reset_environment(robot, events, cfg.reset_time_s, cfg.fps)

        # Re-record: drop the episode buffer and loop without incrementing
        # the counter, so the same episode index is recorded again.
        if events["rerecord_episode"]:
            log_say("Re-record episode", cfg.play_sounds)
            pprint("Re-record episode")
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

        # ESC was pressed during the episode: stop after saving it.
        if events["stop_recording"]:
            break

    log_say("Stop recording", cfg.play_sounds, blocking=True)
    stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset
|
||||
|
||||
|
||||
def replay(
    robot: AgilexRobot,
    cfg,
):
    """Replay the recorded actions of one episode on the robot at `cfg.fps`.

    Loads episode `cfg.episode` of the dataset at `cfg.root`/`cfg.repo_id`
    and sends each stored action in sequence, pacing frames with busy_wait.
    """
    # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
    # TODO(rcadene): Add option to record logs

    dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode])
    actions = dataset.hf_dataset.select_columns("action")

    # if not robot.is_connected:
    #     robot.connect()

    log_say("Replaying episode", cfg.play_sounds, blocking=True)
    frame_period = 1 / cfg.fps
    for frame_idx in range(dataset.num_frames):
        frame_start_t = time.perf_counter()

        robot.send_action(actions[frame_idx]["action"])

        # Sleep away whatever remains of this frame's time budget.
        elapsed_s = time.perf_counter() - frame_start_t
        busy_wait(frame_period - elapsed_s)
        # log_control_info(robot, time.perf_counter() - frame_start_t, fps=cfg.fps)
|
||||
|
||||
|
||||
import argparse
def get_arguments():
    """Build the run configuration as an argparse Namespace.

    No actual command-line options are declared; every setting is hard-coded
    below and attached to the parsed (empty) namespace. Edit this table to
    change recording/replay behavior.
    """
    arg_parser = argparse.ArgumentParser()
    args = arg_parser.parse_args()

    settings = {
        "fps": 30,
        "resume": False,
        "repo_id": "move_the_bottle_from_the_right_to_the_scale_right",
        "root": "/home/ubuntu/LYT/aloha_lerobot/data4",
        "episode": 0,  # which episode to replay
        "num_image_writer_processes": 0,
        "num_image_writer_threads_per_camera": 4,
        "video": True,
        "num_episodes": 100,
        "episode_time_s": 30000,
        "play_sounds": False,
        "display_cameras": True,
        "single_task": "move the bottle from the right to the scale right",
        "use_depth_image": False,
        "use_base": False,
        "push_to_hub": False,
        "policy": None,
        # "record" or "replay"
        "control_type": "record",
    }
    for name, value in settings.items():
        setattr(args, name, value)

    return args
|
||||
|
||||
|
||||
|
||||
# @parser.wrap()
|
||||
# def control_robot(cfg: ControlPipelineConfig):
|
||||
# init_logging()
|
||||
# logging.info(pformat(asdict(cfg)))
|
||||
|
||||
# # robot = make_robot_from_config(cfg.robot)
|
||||
# from agilex_robot import AgilexRobot
|
||||
# robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)
|
||||
|
||||
# if isinstance(cfg.control, RecordControlConfig):
|
||||
# print(cfg.control)
|
||||
# record(robot, cfg.control)
|
||||
# elif isinstance(cfg.control, ReplayControlConfig):
|
||||
# replay(robot, cfg.control)
|
||||
|
||||
# # if robot.is_connected:
|
||||
# # # Disconnect manually to avoid a "Core dump" during process
|
||||
# # # termination due to camera threads not properly exiting.
|
||||
# # robot.disconnect()
|
||||
|
||||
|
||||
# @parser.wrap()
|
||||
def control_robot(cfg):
    """Instantiate the Agilex robot and dispatch on `cfg.control_type`.

    "record" runs data collection, "replay" plays back a stored episode;
    any other value is silently ignored.
    """
    # robot = make_robot_from_config(cfg.robot)
    from agilex_robot import AgilexRobot
    robot = AgilexRobot(config_file="/home/ubuntu/LYT/aloha_lerobot/collect_data/agilex.yaml", args=cfg)

    mode_handlers = {
        "record": record,
        "replay": replay,
    }
    handler = mode_handlers.get(cfg.control_type)
    if handler is not None:
        handler(robot, cfg)

    # Disconnect is intentionally skipped here; the original note warned
    # about a "Core dump" from camera threads during process termination:
    # if robot.is_connected:
    #     robot.disconnect()
|
||||
|
||||
if __name__ == "__main__":
    # Build the hard-coded run configuration, then start the selected
    # control mode ("record" or "replay").
    control_robot(get_arguments())
||||
Reference in New Issue
Block a user