Files
lerobot_aloha/lerobot_aloha/common/utils/data_utils.py
2025-04-07 20:32:39 +08:00

106 lines
3.6 KiB
Python

import logging
import time
from pprint import pprint
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import log_say, has_method
from common.utils.control_utils import init_keyboard_listener, stop_recording, record_episode
def _create_or_load_dataset(robot, cfg) -> LeRobotDataset:
    """Return the dataset that new episodes will be recorded into.

    With ``cfg.resume`` set, the existing dataset ``cfg.repo_id`` under
    ``cfg.root`` is opened and, when the robot has at least one camera, its
    image writer is started. Otherwise a new, empty dataset is created from
    ``robot.features``.
    """
    num_cameras = len(robot.cameras)
    image_writer_threads = cfg.num_image_writer_threads_per_camera * num_cameras

    if cfg.resume:
        dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
        # The image writer is only started when the robot has cameras.
        if num_cameras > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=image_writer_threads,
            )
        return dataset

    return LeRobotDataset.create(
        cfg.repo_id,
        cfg.fps,
        root=cfg.root,
        robot=None,
        features=robot.features,
        use_videos=cfg.video,
        image_writer_processes=cfg.num_image_writer_processes,
        image_writer_threads=image_writer_threads,
    )


def _record_episodes(robot, dataset, events, policy, cfg) -> None:
    """Record up to ``cfg.num_episodes`` new episodes into ``dataset``.

    Keyboard flags in ``events`` are honored after each episode:
    ``rerecord_episode`` discards the episode buffer and records the episode
    again; ``stop_recording`` ends the session after the current episode.
    If both are set, the current episode is discarded and the session ends.
    """
    recorded_episodes = 0
    while recorded_episodes < cfg.num_episodes:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        print(
            f"Recording episode {dataset.num_episodes}, "
            f"total episodes is {cfg.num_episodes}"
        )
        record_episode(
            robot=robot,
            dataset=dataset,
            events=events,
            episode_time_s=cfg.episode_time_s,
            display_cameras=cfg.display_cameras,
            policy=policy,
            fps=cfg.fps,
            single_task=cfg.single_task,
        )

        # Skip the reset prompt after the last episode to be recorded, unless
        # that episode is about to be re-recorded.
        more_to_record = recorded_episodes < cfg.num_episodes - 1
        if not events["stop_recording"] and (more_to_record or events["rerecord_episode"]):
            log_say("Reset the environment", cfg.play_sounds)
            print("Reset the environment, stop recording")

        if events["rerecord_episode"]:
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            if events["stop_recording"]:
                # A stop request wins over re-record: drop this episode and end
                # the session instead of starting another full-length recording.
                break
            log_say("Re-record episode", cfg.play_sounds)
            print("Re-record episode")
            continue

        dataset.save_episode()
        recorded_episodes += 1

        if events["stop_recording"]:
            break


def record(
    robot,
    cfg
) -> LeRobotDataset:
    """Record robot episodes into a LeRobot dataset as configured by ``cfg``.

    Opens (``cfg.resume``) or creates the dataset, optionally builds a policy
    from ``cfg.policy``, then records ``cfg.num_episodes`` new episodes while
    listening to the keyboard for re-record / stop requests. Teardown via
    ``stop_recording`` runs even if recording raises. The dataset is pushed to
    the Hub afterwards when ``cfg.push_to_hub`` is set.

    Args:
        robot: Robot instance; ``robot.cameras`` and ``robot.features`` are read.
        cfg: Configuration object holding the recording settings read here
            (``repo_id``, ``root``, ``fps``, ``num_episodes``,
            ``episode_time_s``, ``policy``, ``push_to_hub``, ...).

    Returns:
        LeRobotDataset: Dataset with recorded episodes.
    """
    dataset = _create_or_load_dataset(robot, cfg)

    # Load pretrained policy, if one is configured.
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    listener, events = init_keyboard_listener()

    # Operator instructions (Chinese): number of episodes to record, maximum
    # duration per episode in seconds, and the key bindings
    # (Right = end current episode, Up = start, Left = re-record, Esc = quit).
    print()
    print(
        f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n"
        f"每条轨迹的最长时间为{cfg.episode_time_s}秒\n"
        "按右方向键代表当前轨迹结束录制\n"
        "按上方向键代表当前轨迹开始录制\n"
        "按左方向键代表当前轨迹重新录制\n"
        "按ESC键代表退出轨迹录制\n"
    )

    try:
        _record_episodes(robot, dataset, events, policy, cfg)
        log_say("Stop recording", cfg.play_sounds, blocking=True)
    finally:
        # stop_recording is handed the robot, the keyboard listener and the
        # display flag; run it on every exit path (including Ctrl+C) so they
        # are torn down even when recording fails part-way.
        stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset