106 lines
3.6 KiB
Python
106 lines
3.6 KiB
Python
import logging
|
|
import time
|
|
from pprint import pprint
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.common.policies.factory import make_policy
|
|
from lerobot.common.utils.utils import log_say, has_method
|
|
from common.utils.control_utils import init_keyboard_listener, stop_recording, record_episode
|
|
|
|
|
|
def record(
    robot,
    cfg,
) -> LeRobotDataset:
    """Record robot episodes into a LeRobotDataset according to ``cfg``.

    Resumes the dataset at ``cfg.root`` when ``cfg.resume`` is set, otherwise
    creates a new one. Optionally builds a policy from ``cfg.policy``, then
    records ``cfg.num_episodes`` episodes, reacting to the keyboard ``events``
    dict returned by ``init_keyboard_listener``:

    * ``events["rerecord_episode"]``: the current episode buffer is discarded
      and the episode is recorded again (it does not count toward the total).
    * ``events["stop_recording"]``: the current episode is saved and the
      session ends early.

    The keyboard listener and robot are always released via ``stop_recording``,
    even if recording raises.

    Args:
        robot: Robot instance. ``robot.cameras`` (sized) and ``robot.features``
            are read here.
        cfg: Recording configuration. Fields read here: ``resume``, ``repo_id``,
            ``root``, ``fps``, ``video``, ``num_image_writer_processes``,
            ``num_image_writer_threads_per_camera``, ``policy``,
            ``num_episodes``, ``episode_time_s``, ``display_cameras``,
            ``single_task``, ``play_sounds``, ``push_to_hub``, ``tags``,
            ``private``.

    Returns:
        LeRobotDataset: Dataset with the recorded episodes.
    """
    dataset = _init_dataset(robot, cfg)

    # Load the pretrained policy, if any; it is built against the dataset metadata.
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    listener, events = init_keyboard_listener()

    try:
        _print_instructions(cfg)

        recorded_episodes = 0
        while recorded_episodes < cfg.num_episodes:
            log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
            print(f"Recording episode {dataset.num_episodes}, total episodes is {cfg.num_episodes}")
            record_episode(
                robot=robot,
                dataset=dataset,
                events=events,
                episode_time_s=cfg.episode_time_s,
                display_cameras=cfg.display_cameras,
                policy=policy,
                fps=cfg.fps,
                single_task=cfg.single_task,
            )

            # Prompt the operator to reset the scene between episodes. This only
            # logs; no automatic reset happens here. Skipped after the last
            # episode unless that episode is about to be re-recorded.
            if not events["stop_recording"] and (
                recorded_episodes < cfg.num_episodes - 1 or events["rerecord_episode"]
            ):
                log_say("Reset the environment", cfg.play_sounds)
                print("Reset the environment, stop recording")

            if events["rerecord_episode"]:
                log_say("Re-record episode", cfg.play_sounds)
                print("Re-record episode")
                # Clear the flags so the next attempt starts fresh, and drop the
                # frames buffered for the discarded attempt.
                events["rerecord_episode"] = False
                events["exit_early"] = False
                dataset.clear_episode_buffer()
                continue

            dataset.save_episode()
            recorded_episodes += 1

            # Stop was requested: the episode just recorded is kept, then we exit.
            if events["stop_recording"]:
                break

        log_say("Stop recording", cfg.play_sounds, blocking=True)
    finally:
        # Release the keyboard listener / robot even if recording raised
        # (including KeyboardInterrupt).
        stop_recording(robot, listener, cfg.display_cameras)

    if cfg.push_to_hub:
        dataset.push_to_hub(tags=cfg.tags, private=cfg.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset


def _init_dataset(robot, cfg) -> LeRobotDataset:
    """Load the existing dataset when ``cfg.resume`` is set, otherwise create a new one."""
    num_cameras = len(robot.cameras)

    if cfg.resume:
        dataset = LeRobotDataset(
            cfg.repo_id,
            root=cfg.root,
        )
        # Start the image writer only when the robot has cameras.
        if num_cameras > 0:
            dataset.start_image_writer(
                num_processes=cfg.num_image_writer_processes,
                num_threads=cfg.num_image_writer_threads_per_camera * num_cameras,
            )
        return dataset

    return LeRobotDataset.create(
        cfg.repo_id,
        cfg.fps,
        root=cfg.root,
        robot=None,
        features=robot.features,
        use_videos=cfg.video,
        image_writer_processes=cfg.num_image_writer_processes,
        image_writer_threads=cfg.num_image_writer_threads_per_camera * num_cameras,
    )


def _print_instructions(cfg) -> None:
    """Print the operator instructions (episode count, max duration, key bindings)."""
    print()
    # episode_time_s is a duration in seconds, so report it as such.
    print(
        f"开始记录轨迹,共需要记录{cfg.num_episodes}条\n"
        f"每条轨迹的最长时间为{cfg.episode_time_s}秒\n"
        "按右方向键代表当前轨迹结束录制\n"
        "按上方向键代表当前轨迹开始录制\n"
        "按左方向键代表当前轨迹重新录制\n"
        "按ESC键代表退出轨迹录制\n"
    )
|