From 8b983992069108b5d4113719b6aa623ee56d978e Mon Sep 17 00:00:00 2001 From: Simon Alibert Date: Thu, 8 May 2025 13:15:37 +0200 Subject: [PATCH] Add record --- lerobot/common/datasets/lerobot_dataset.py | 13 +- lerobot/common/datasets/utils.py | 66 ++++- lerobot/common/utils/control_utils.py | 8 +- lerobot/record.py | 325 +++++++++++++++++++++ 4 files changed, 387 insertions(+), 25 deletions(-) create mode 100644 lerobot/record.py diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index cf0fb0463..c79e49d94 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -72,7 +72,7 @@ from lerobot.common.datasets.video_utils import ( get_safe_default_codec, get_video_info, ) -from lerobot.common.robots.utils import Robot +from lerobot.common.robots import Robot CODEBASE_VERSION = "v2.1" @@ -785,7 +785,7 @@ class LeRobotDataset(torch.utils.data.Dataset): else: self.image_writer.save_image(image=image, fpath=fpath) - def add_frame(self, frame: dict) -> None: + def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None: """ This function only adds the frame to the episode_buffer. Apart from images — which are written in a temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method @@ -803,17 +803,14 @@ class LeRobotDataset(torch.utils.data.Dataset): # Automatically add frame_index and timestamp to episode buffer frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps + if timestamp is None: + timestamp = frame_index / self.fps self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) + self.episode_buffer["task"].append(task) # Add frame features to episode_buffer for key in frame: - if key == "task": - # Note: we associate the task in natural language to its task index during `save_episode` - self.episode_buffer["task"].append(frame["task"]) - continue - if key not in self.features: raise ValueError( f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index f8b016cdc..31a3cca78 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -40,7 +40,7 @@ from lerobot.common.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) -from lerobot.common.robots.utils import Robot +from lerobot.common.robots import Robot from lerobot.common.utils.utils import is_valid_numpy_dtype_string from lerobot.configs.types import DictLike, FeatureType, PolicyFeature @@ -387,6 +387,52 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: return datasets.Features(hf_features) +def _validate_feature_names(features: dict[str, dict]) -> None: + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") + + +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + features = {} + joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float} + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts: + features[f"{prefix}.joints"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.cameras.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + frame = {} + for key, ft in ds_features.items(): + if key in DEFAULT_FEATURES or not key.startswith(prefix): + continue + elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.cameras.")] + + return frame + + def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: camera_ft = {} if robot.cameras: @@ -699,16 +745,12 @@ class IterableNamespace(SimpleNamespace): def validate_frame(frame: dict, features: dict): - optional_features = {"timestamp"} - expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} - actual_features = set(frame.keys()) + expected_features = set(features) - set(DEFAULT_FEATURES) + actual_features = set(frame) - error_message = validate_features_presence(actual_features, expected_features, optional_features) + error_message = validate_features_presence(actual_features, expected_features) - if "task" in frame: - error_message += validate_feature_string("task", frame["task"]) - - common_features = actual_features & (expected_features | optional_features) + common_features = actual_features & expected_features for name in common_features - {"task"}: error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) @@ -716,12 +758,10 @@ def validate_frame(frame: dict, features: dict): raise ValueError(error_message) -def validate_features_presence( - actual_features: set[str], expected_features: set[str], optional_features: set[str] -): +def validate_features_presence(actual_features: set[str], expected_features: set[str]): error_message = "" missing_features = expected_features - actual_features - extra_features = actual_features - (expected_features | optional_features) + extra_features = actual_features - expected_features if missing_features or extra_features: error_message += "Feature mismatch in `frame` dictionary:\n" diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py index 12524b699..a39b4fcc2 100644 --- a/lerobot/common/utils/control_utils.py +++ b/lerobot/common/utils/control_utils.py @@ -31,9 +31,9 @@ from termcolor import colored from lerobot.common.datasets.image_writer import safe_stop_image_writer from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.utils import get_features_from_robot +from lerobot.common.datasets.utils import DEFAULT_FEATURES from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.robots.utils import Robot +from lerobot.common.robots import Robot from lerobot.common.utils.robot_utils import busy_wait from lerobot.common.utils.utils import get_safe_torch_device, has_method @@ -327,12 +327,12 @@ def sanity_check_dataset_name(repo_id, policy_cfg): def sanity_check_dataset_robot_compatibility( - dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool + dataset: LeRobotDataset, robot: Robot, fps: int, features: dict ) -> None: fields = [ ("robot_type", dataset.meta.robot_type, robot.robot_type), ("fps", dataset.fps, fps), - ("features", dataset.features, get_features_from_robot(robot, use_videos)), + ("features", dataset.features, {**features, **DEFAULT_FEATURES}), ] mismatches = [] diff --git a/lerobot/record.py b/lerobot/record.py new file mode 100644 index 000000000..004e6502b --- /dev/null +++ b/lerobot/record.py @@ -0,0 +1,325 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from pprint import pformat + +import draccus +import numpy as np +import rerun as rr + +from lerobot.common.cameras import ( # noqa: F401 + CameraConfig, # noqa: F401 + intel, + opencv, +) +from lerobot.common.datasets.image_writer import safe_stop_image_writer +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( # noqa: F401 + Teleoperator, + TeleoperatorConfig, + make_teleoperator_from_config, +) +from lerobot.common.utils.control_utils import ( + init_keyboard_listener, + is_headless, + predict_action, + sanity_check_dataset_name, + sanity_check_dataset_robot_compatibility, +) +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import ( + get_safe_torch_device, + init_logging, + log_say, +) +from lerobot.common.utils.visualization_utils import _init_rerun +from lerobot.configs import parser +from lerobot.configs.policies import PreTrainedConfig + +from .common.teleoperators import koch_leader, so100_leader # noqa: F401 + + +@dataclass +class DatasetRecordConfig: + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str + # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") + single_task: str + # Root directory where the dataset will be stored (e.g. 'dataset/path'). + root: str | Path | None = None + # Limit the frames per second. By default, uses the policy fps. + fps: int = 30 + # Number of seconds for data recording for each episode. + episode_time_s: int | float = 60 + # Number of seconds for resetting the environment after each episode. + reset_time_s: int | float = 60 + # Number of episodes to record. + num_episodes: int = 50 + # Encode frames in the dataset into video + video: bool = True + # Upload dataset to Hugging Face hub. + push_to_hub: bool = True + # Upload on private repository on the Hugging Face hub. + private: bool = False + # Add tags to your dataset on the hub. + tags: list[str] | None = None + # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; + # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes + # and threads depends on your system. We recommend 4 threads per camera with 0 processes. + # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. + num_image_writer_processes: int = 0 + # Number of threads writing the frames as png images on disk, per camera. + # Too many threads might cause unstable teleoperation fps due to main thread being blocked. + # Not enough threads might cause low camera fps. + num_image_writer_threads_per_camera: int = 4 + + def __post_init__(self): + if self.single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + + +@dataclass +class RecordConfig: + robot: RobotConfig + dataset: DatasetRecordConfig + # Whether to control the robot with a teleoperator + teleop: TeleoperatorConfig | None = None + # Whether to control the robot with a policy + policy: PreTrainedConfig | None = None + # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize. + warmup_time_s: int | float = 10 + # Display all cameras on screen + display_data: bool = False + # Use vocal synthesis to read events. + play_sounds: bool = True + # Resume recording on an existing dataset. + resume: bool = False + + def __post_init__(self): + if bool(self.teleop) == bool(self.policy): + raise ValueError("Choose either a policy or a teleoperator to control the robot") + + # HACK: We parse again the cli args here to get the pretrained path if there was one. + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + + +@safe_stop_image_writer +def record_loop( + robot: Robot, + events: dict, + dataset: LeRobotDataset | None = None, + teleop: Teleoperator | None = None, + policy: PreTrainedPolicy | None = None, + control_time_s: int | None = None, + fps: int | None = None, + single_task: str | None = None, + display_data: bool = False, +): + if dataset is not None and fps is not None and dataset.fps != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") + + timestamp = 0 + start_episode_t = time.perf_counter() + while timestamp < control_time_s: + start_loop_t = time.perf_counter() + + observation = robot.get_observation() + + if policy is not None: + action = predict_action( + observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + ) + else: + action = teleop.get_action() + + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + sent_action = robot.send_action(action) + + if dataset is not None: + observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") + frame = {**observation_frame, **action_frame} + dataset.add_frame(frame, task=single_task) + + if display_data: + for obs, val in observation.items(): + if isinstance(val, float): + rr.log(f"observation.{obs}", rr.Scalar(val)) + elif isinstance(val, np.ndarray): + rr.log(f"observation.{obs}", rr.Image(val), static=True) + for act, val in action.items(): + if isinstance(val, float): + rr.log(f"action.{act}", rr.Scalar(val)) + + if fps is not None: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + dt_s = time.perf_counter() - start_loop_t + # log_control_info(robot, dt_s, fps=fps) + + timestamp = time.perf_counter() - start_episode_t + if events["exit_early"]: + events["exit_early"] = False + break + + +@draccus.wrap() +def record(cfg: RecordConfig): + init_logging() + logging.info(pformat(asdict(cfg))) + if cfg.display_data: + _init_rerun(session_name="recording") + + robot = make_robot_from_config(cfg.robot) + teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None + + action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video) + obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video) + dataset_features = {**action_features, **obs_features} + + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + ) + # for key, ft in dataset_features.items(): + # for property in ["dtype", "shape", "names"]: + # if ft[property] != dataset.features[key][property]: + # raise ValueError(ft) + + if hasattr(robot, "cameras") and len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.dataset.num_image_writer_processes, + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + ) + sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features) + else: + # Create empty dataset or load existing saved episodes + sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy) + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.dataset.fps, + root=cfg.dataset.root, + robot_type=robot.name, + features=dataset_features, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + + robot.connect() + if teleop is not None: + teleop.connect() + + listener, events = init_keyboard_listener() + + # Execute a few seconds without recording to: + # 1. teleoperate the robot to move it in starting position if no policy provided, + # 2. give times to the robot devices to connect and start synchronizing, + # 3. place the cameras windows on screen + # enable_teleoperation = policy is None + # log_say("Warmup record", cfg.play_sounds) + # record_loop( + # robot=robot, + # control_time_s=cfg.warmup_time_s, + # display_data=cfg.display_data, + # events=events, + # fps=cfg.dataset.fps, + # teleoperate=enable_teleoperation, + # ) + # warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.dataset.fps) + + for recorded_episodes in range(cfg.dataset.num_episodes): + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) + record_loop( + robot=robot, + events=events, + teleop=teleop, + policy=policy, + dataset=dataset, + control_time_s=cfg.dataset.episode_time_s, + fps=cfg.dataset.fps, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + ) + + # Execute a few seconds without recording to give time to manually reset the environment + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + record_loop( + robot=robot, + events=events, + teleop=teleop, + control_time_s=cfg.dataset.reset_time_s, + fps=cfg.dataset.fps, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + ) + # reset_environment(robot, events, cfg.dataset.reset_time_s, cfg.dataset.fps) + + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + + if events["stop_recording"]: + break + + log_say("Stop recording", cfg.play_sounds, blocking=True) + + robot.disconnect() + teleop.disconnect() + + if not is_headless() and listener is not None: + listener.stop() + + if cfg.dataset.push_to_hub: + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + + log_say("Exiting", cfg.play_sounds) + return dataset + + +if __name__ == "__main__": + record()