Add record

This commit is contained in:
Simon Alibert
2025-05-08 13:15:37 +02:00
parent 237b14a6ec
commit 8b98399206
4 changed files with 387 additions and 25 deletions

325
lerobot/record.py Normal file
View File

@@ -0,0 +1,325 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
import draccus
import numpy as np
import rerun as rr
from lerobot.common.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
intel,
opencv,
)
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
)
from lerobot.common.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
make_teleoperator_from_config,
)
from lerobot.common.utils.control_utils import (
init_keyboard_listener,
is_headless,
predict_action,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_logging,
log_say,
)
from lerobot.common.utils.visualization_utils import _init_rerun
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from .common.teleoperators import koch_leader, so100_leader # noqa: F401
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
def __post_init__(self):
if self.single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
@dataclass
class RecordConfig:
robot: RobotConfig
dataset: DatasetRecordConfig
# Whether to control the robot with a teleoperator
teleop: TeleoperatorConfig | None = None
# Whether to control the robot with a policy
policy: PreTrainedConfig | None = None
# Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
warmup_time_s: int | float = 10
# Display all cameras on screen
display_data: bool = False
# Use vocal synthesis to read events.
play_sounds: bool = True
# Resume recording on an existing dataset.
resume: bool = False
def __post_init__(self):
if bool(self.teleop) == bool(self.policy):
raise ValueError("Choose either a policy or a teleoperator to control the robot")
# HACK: We parse again the cli args here to get the pretrained path if there was one.
policy_path = parser.get_path_arg("policy")
if policy_path:
cli_overrides = parser.get_cli_overrides("policy")
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
self.policy.pretrained_path = policy_path
@safe_stop_image_writer
def record_loop(
robot: Robot,
events: dict,
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | None = None,
policy: PreTrainedPolicy | None = None,
control_time_s: int | None = None,
fps: int | None = None,
single_task: str | None = None,
display_data: bool = False,
):
if dataset is not None and fps is not None and dataset.fps != fps:
raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")
timestamp = 0
start_episode_t = time.perf_counter()
while timestamp < control_time_s:
start_loop_t = time.perf_counter()
observation = robot.get_observation()
if policy is not None:
action = predict_action(
observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
)
else:
action = teleop.get_action()
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
sent_action = robot.send_action(action)
if dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
frame = {**observation_frame, **action_frame}
dataset.add_frame(frame, task=single_task)
if display_data:
for obs, val in observation.items():
if isinstance(val, float):
rr.log(f"observation.{obs}", rr.Scalar(val))
elif isinstance(val, np.ndarray):
rr.log(f"observation.{obs}", rr.Image(val), static=True)
for act, val in action.items():
if isinstance(val, float):
rr.log(f"action.{act}", rr.Scalar(val))
if fps is not None:
dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
dt_s = time.perf_counter() - start_loop_t
# log_control_info(robot, dt_s, fps=fps)
timestamp = time.perf_counter() - start_episode_t
if events["exit_early"]:
events["exit_early"] = False
break
@draccus.wrap()
def record(cfg: RecordConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
if cfg.display_data:
_init_rerun(session_name="recording")
robot = make_robot_from_config(cfg.robot)
teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
dataset_features = {**action_features, **obs_features}
if cfg.resume:
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
)
# for key, ft in dataset_features.items():
# for property in ["dtype", "shape", "names"]:
# if ft[property] != dataset.features[key][property]:
# raise ValueError(ft)
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.dataset.num_image_writer_processes,
num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.dataset.fps,
root=cfg.dataset.root,
robot_type=robot.name,
features=dataset_features,
use_videos=cfg.dataset.video,
image_writer_processes=cfg.dataset.num_image_writer_processes,
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
robot.connect()
if teleop is not None:
teleop.connect()
listener, events = init_keyboard_listener()
# Execute a few seconds without recording to:
# 1. teleoperate the robot to move it in starting position if no policy provided,
# 2. give times to the robot devices to connect and start synchronizing,
# 3. place the cameras windows on screen
# enable_teleoperation = policy is None
# log_say("Warmup record", cfg.play_sounds)
# record_loop(
# robot=robot,
# control_time_s=cfg.warmup_time_s,
# display_data=cfg.display_data,
# events=events,
# fps=cfg.dataset.fps,
# teleoperate=enable_teleoperation,
# )
# warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.dataset.fps)
for recorded_episodes in range(cfg.dataset.num_episodes):
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
teleop=teleop,
policy=policy,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
fps=cfg.dataset.fps,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
# Execute a few seconds without recording to give time to manually reset the environment
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", cfg.play_sounds)
record_loop(
robot=robot,
events=events,
teleop=teleop,
control_time_s=cfg.dataset.reset_time_s,
fps=cfg.dataset.fps,
single_task=cfg.dataset.single_task,
display_data=cfg.display_data,
)
# reset_environment(robot, events, cfg.dataset.reset_time_s, cfg.dataset.fps)
if events["rerecord_episode"]:
log_say("Re-record episode", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode()
if events["stop_recording"]:
break
log_say("Stop recording", cfg.play_sounds, blocking=True)
robot.disconnect()
teleop.disconnect()
if not is_headless() and listener is not None:
listener.stop()
if cfg.dataset.push_to_hub:
dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)
log_say("Exiting", cfg.play_sounds)
return dataset
if __name__ == "__main__":
record()