Add record

This commit is contained in:
Simon Alibert
2025-05-08 13:15:37 +02:00
parent 237b14a6ec
commit 8b98399206
4 changed files with 387 additions and 25 deletions

View File

@@ -72,7 +72,7 @@ from lerobot.common.datasets.video_utils import (
get_safe_default_codec,
get_video_info,
)
from lerobot.common.robots.utils import Robot
from lerobot.common.robots import Robot
CODEBASE_VERSION = "v2.1"
@@ -785,7 +785,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
else:
self.image_writer.save_image(image=image, fpath=fpath)
def add_frame(self, frame: dict) -> None:
def add_frame(self, frame: dict, task: str, timestamp: float | None = None) -> None:
"""
This function only adds the frame to the episode_buffer. Apart from images — which are written in a
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
@@ -803,17 +803,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Automatically add frame_index and timestamp to episode buffer
frame_index = self.episode_buffer["size"]
timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
if timestamp is None:
timestamp = frame_index / self.fps
self.episode_buffer["frame_index"].append(frame_index)
self.episode_buffer["timestamp"].append(timestamp)
self.episode_buffer["task"].append(task)
# Add frame features to episode_buffer
for key in frame:
if key == "task":
# Note: we associate the task in natural language to its task index during `save_episode`
self.episode_buffer["task"].append(frame["task"])
continue
if key not in self.features:
raise ValueError(
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."

View File

@@ -40,7 +40,7 @@ from lerobot.common.datasets.backward_compatibility import (
BackwardCompatibilityError,
ForwardCompatibilityError,
)
from lerobot.common.robots.utils import Robot
from lerobot.common.robots import Robot
from lerobot.common.utils.utils import is_valid_numpy_dtype_string
from lerobot.configs.types import DictLike, FeatureType, PolicyFeature
@@ -387,6 +387,52 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
return datasets.Features(hf_features)
def _validate_feature_names(features: dict[str, dict]) -> None:
    """Ensure that no feature name contains a forward slash.

    Args:
        features: Mapping from feature name to its specification.

    Raises:
        ValueError: If at least one key of `features` contains '/'. The
            message embeds the offending entries.
    """
    offending = {}
    for feature_name, spec in features.items():
        if "/" in feature_name:
            offending[feature_name] = spec

    if not offending:
        return

    raise ValueError(f"Feature names should not contain '/'. Found '/' in '{offending}'.")
def hw_to_dataset_features(
    hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True
) -> dict[str, dict]:
    """Translate a hardware feature description into dataset feature specs.

    Entries of `hw_features` whose value is the `float` type are grouped into
    a single 1-D float32 feature named '{prefix}.joints'. Entries whose value
    is a tuple are treated as camera shapes and each becomes its own
    '{prefix}.cameras.{key}' feature. Any other entry is ignored.

    Args:
        hw_features: Mapping from hardware feature name to either `float` or
            a shape tuple.
        prefix: Namespace prepended to every produced feature name.
        use_video: Cameras get dtype "video" when true, "image" otherwise.

    Returns:
        The dataset feature dictionary (joints first, then cameras).

    Raises:
        ValueError: If a produced feature name contains '/'.
    """
    joint_names = [name for name, kind in hw_features.items() if kind is float]
    camera_dtype = "video" if use_video else "image"

    dataset_features = {}
    if joint_names:
        dataset_features[f"{prefix}.joints"] = {
            "dtype": "float32",
            "shape": (len(joint_names),),
            "names": joint_names,
        }

    dataset_features.update(
        {
            f"{prefix}.cameras.{cam_name}": {
                "dtype": camera_dtype,
                "shape": cam_shape,
                "names": ["height", "width", "channels"],
            }
            for cam_name, cam_shape in hw_features.items()
            if isinstance(cam_shape, tuple)
        }
    )

    _validate_feature_names(dataset_features)
    return dataset_features
def build_dataset_frame(
    ds_features: dict[str, dict], values: dict[str, Any], prefix: str
) -> dict[str, np.ndarray]:
    """Assemble the part of a dataset frame that lives under `prefix`.

    Features listed in `DEFAULT_FEATURES` and features whose name does not
    start with `prefix` are skipped. For the remaining ones:

    * a 1-D "float32" feature is filled by looking up each of its `names` in
      `values` and packing them into a float32 array;
    * an "image"/"video" feature is taken from `values` under the feature
      name stripped of the '{prefix}.cameras.' prefix.

    Features of any other kind are left out of the result.

    Args:
        ds_features: Dataset feature specifications keyed by feature name.
        values: Raw values keyed by joint/camera name.
        prefix: Namespace selecting which features to build.

    Returns:
        Mapping from feature name to the value to store for this frame.
    """
    camera_prefix = f"{prefix}.cameras."
    built = {}
    for feature_name, spec in ds_features.items():
        if feature_name not in DEFAULT_FEATURES and feature_name.startswith(prefix):
            dtype = spec["dtype"]
            if dtype == "float32" and len(spec["shape"]) == 1:
                ordered_values = [values[name] for name in spec["names"]]
                built[feature_name] = np.array(ordered_values, dtype=np.float32)
            elif dtype in ["image", "video"]:
                built[feature_name] = values[feature_name.removeprefix(camera_prefix)]
    return built
def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict:
camera_ft = {}
if robot.cameras:
@@ -699,16 +745,12 @@ class IterableNamespace(SimpleNamespace):
def validate_frame(frame: dict, features: dict):
optional_features = {"timestamp"}
expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"}
actual_features = set(frame.keys())
expected_features = set(features) - set(DEFAULT_FEATURES)
actual_features = set(frame)
error_message = validate_features_presence(actual_features, expected_features, optional_features)
error_message = validate_features_presence(actual_features, expected_features)
if "task" in frame:
error_message += validate_feature_string("task", frame["task"])
common_features = actual_features & (expected_features | optional_features)
common_features = actual_features & expected_features
for name in common_features - {"task"}:
error_message += validate_feature_dtype_and_shape(name, features[name], frame[name])
@@ -716,12 +758,10 @@ def validate_frame(frame: dict, features: dict):
raise ValueError(error_message)
def validate_features_presence(
actual_features: set[str], expected_features: set[str], optional_features: set[str]
):
def validate_features_presence(actual_features: set[str], expected_features: set[str]):
error_message = ""
missing_features = expected_features - actual_features
extra_features = actual_features - (expected_features | optional_features)
extra_features = actual_features - expected_features
if missing_features or extra_features:
error_message += "Feature mismatch in `frame` dictionary:\n"

View File

@@ -31,9 +31,9 @@ from termcolor import colored
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import get_features_from_robot
from lerobot.common.datasets.utils import DEFAULT_FEATURES
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robots.utils import Robot
from lerobot.common.robots import Robot
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import get_safe_torch_device, has_method
@@ -327,12 +327,12 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
def sanity_check_dataset_robot_compatibility(
dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool
dataset: LeRobotDataset, robot: Robot, fps: int, features: dict
) -> None:
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
("fps", dataset.fps, fps),
("features", dataset.features, get_features_from_robot(robot, use_videos)),
("features", dataset.features, {**features, **DEFAULT_FEATURES}),
]
mismatches = []

325
lerobot/record.py Normal file
View File

@@ -0,0 +1,325 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from pprint import pformat
import draccus
import numpy as np
import rerun as rr
from lerobot.common.cameras import ( # noqa: F401
CameraConfig, # noqa: F401
intel,
opencv,
)
from lerobot.common.datasets.image_writer import safe_stop_image_writer
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.robots import ( # noqa: F401
Robot,
RobotConfig,
koch_follower,
make_robot_from_config,
so100_follower,
)
from lerobot.common.teleoperators import ( # noqa: F401
Teleoperator,
TeleoperatorConfig,
make_teleoperator_from_config,
)
from lerobot.common.utils.control_utils import (
init_keyboard_listener,
is_headless,
predict_action,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.common.utils.robot_utils import busy_wait
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_logging,
log_say,
)
from lerobot.common.utils.visualization_utils import _init_rerun
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from .common.teleoperators import koch_leader, so100_leader # noqa: F401
@dataclass
class DatasetRecordConfig:
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
repo_id: str
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
single_task: str
# Root directory where the dataset will be stored (e.g. 'dataset/path').
root: str | Path | None = None
# Limit the frames per second. By default, uses the policy fps.
fps: int = 30
# Number of seconds for data recording for each episode.
episode_time_s: int | float = 60
# Number of seconds for resetting the environment after each episode.
reset_time_s: int | float = 60
# Number of episodes to record.
num_episodes: int = 50
# Encode frames in the dataset into video
video: bool = True
# Upload dataset to Hugging Face hub.
push_to_hub: bool = True
# Upload on private repository on the Hugging Face hub.
private: bool = False
# Add tags to your dataset on the hub.
tags: list[str] | None = None
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
num_image_writer_processes: int = 0
# Number of threads writing the frames as png images on disk, per camera.
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
# Not enough threads might cause low camera fps.
num_image_writer_threads_per_camera: int = 4
def __post_init__(self):
if self.single_task is None:
raise ValueError("You need to provide a task as argument in `single_task`.")
@dataclass
class RecordConfig:
    """Top-level configuration of a recording session.

    After construction, `__post_init__` first resolves a policy given through
    the "policy" path CLI argument (if any) and then checks that exactly one
    of `teleop` and `policy` is set.

    Raises:
        ValueError: If both or neither of `teleop` and `policy` are set once
            the policy path has been resolved.
    """

    robot: RobotConfig
    dataset: DatasetRecordConfig
    # Whether to control the robot with a teleoperator
    teleop: TeleoperatorConfig | None = None
    # Whether to control the robot with a policy
    policy: PreTrainedConfig | None = None
    # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.
    warmup_time_s: int | float = 10
    # Display all cameras on screen
    display_data: bool = False
    # Use vocal synthesis to read events.
    play_sounds: bool = True
    # Resume recording on an existing dataset.
    resume: bool = False

    def __post_init__(self):
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        # This must run before the validation below: a policy supplied through the path
        # argument only lands in `self.policy` here, so validating first would not see it.
        policy_path = parser.get_path_arg("policy")
        if policy_path:
            cli_overrides = parser.get_cli_overrides("policy")
            self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
            self.policy.pretrained_path = policy_path

        # Exactly one controller (teleoperator or policy) must drive the robot.
        if bool(self.teleop) == bool(self.policy):
            raise ValueError("Choose either a policy or a teleoperator to control the robot")
@safe_stop_image_writer
def record_loop(
    robot: Robot,
    events: dict,
    dataset: LeRobotDataset | None = None,
    teleop: Teleoperator | None = None,
    policy: PreTrainedPolicy | None = None,
    control_time_s: int | None = None,
    fps: int | None = None,
    single_task: str | None = None,
    display_data: bool = False,
):
    """Run the observe/act control loop, optionally recording and displaying it.

    Each iteration reads an observation from the robot, obtains an action from
    `policy` when one is given (otherwise from `teleop`), sends it to the
    robot and, when `dataset` is given, appends the observation and the
    action actually sent as a new frame.

    Args:
        robot: Robot to read observations from and send actions to.
        events: Shared event flags. `events["exit_early"]` is read at the end
            of every iteration; when set it is reset to False and the loop
            stops.
        dataset: Dataset receiving the frames; nothing is recorded if None.
        teleop: Source of actions when no policy is given.
        policy: Source of actions; takes precedence over `teleop`.
        control_time_s: Maximum duration of the loop in seconds. None means
            no time limit (the loop then only ends through `exit_early`).
        fps: Target loop frequency; the loop is not rate-limited if None.
        single_task: Task description stored with every recorded frame.
        display_data: Log observations and actions to rerun when true.

    Raises:
        ValueError: If `dataset.fps` differs from `fps`, or if an action is
            needed while neither `policy` nor `teleop` is provided.
    """
    if dataset is not None and fps is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    # The signature defaults `control_time_s` to None; comparing a float with None
    # would raise a TypeError, so treat it as "run until exited early".
    if control_time_s is None:
        control_time_s = float("inf")

    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        observation = robot.get_observation()

        if policy is not None:
            action = predict_action(
                observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp
            )
        elif teleop is not None:
            action = teleop.get_action()
        else:
            # Fail with an explicit message instead of an AttributeError on None.
            raise ValueError("Either a policy or a teleoperator is required to control the robot.")

        # Action can eventually be clipped using `max_relative_target`,
        # so action actually sent is saved in the dataset.
        sent_action = robot.send_action(action)

        if dataset is not None:
            observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
            action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
            frame = {**observation_frame, **action_frame}
            dataset.add_frame(frame, task=single_task)

        if display_data:
            for obs, val in observation.items():
                if isinstance(val, float):
                    rr.log(f"observation.{obs}", rr.Scalar(val))
                elif isinstance(val, np.ndarray):
                    rr.log(f"observation.{obs}", rr.Image(val), static=True)
            for act, val in action.items():
                if isinstance(val, float):
                    rr.log(f"action.{act}", rr.Scalar(val))

        if fps is not None:
            # Wait out the remainder of the period to hold the target frequency.
            dt_s = time.perf_counter() - start_loop_t
            busy_wait(1 / fps - dt_s)

        timestamp = time.perf_counter() - start_episode_t
        if events["exit_early"]:
            events["exit_early"] = False
            break
@draccus.wrap()
def record(cfg: RecordConfig):
    """Record a dataset of robot episodes according to `cfg`.

    Builds the robot (and teleoperator, if configured), creates the dataset or
    resumes an existing one, then alternates recording and reset phases until
    `cfg.dataset.num_episodes` episodes have been saved or a stop is requested
    through the keyboard events. An episode flagged for re-recording is
    discarded and does not count towards the number of saved episodes.

    Args:
        cfg: Recording configuration.

    Returns:
        The dataset the episodes were saved to.
    """
    init_logging()
    logging.info(pformat(asdict(cfg)))
    if cfg.display_data:
        _init_rerun(session_name="recording")

    robot = make_robot_from_config(cfg.robot)
    teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None

    action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
    obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
    dataset_features = {**action_features, **obs_features}

    # A robot without cameras needs no image-writer threads; computed once so that the
    # resume and create branches handle it the same way.
    num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
    image_writer_threads = cfg.dataset.num_image_writer_threads_per_camera * num_cameras

    if cfg.resume:
        dataset = LeRobotDataset(
            cfg.dataset.repo_id,
            root=cfg.dataset.root,
        )
        if num_cameras > 0:
            dataset.start_image_writer(
                num_processes=cfg.dataset.num_image_writer_processes,
                num_threads=image_writer_threads,
            )
        sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
    else:
        # Create empty dataset or load existing saved episodes
        sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
        dataset = LeRobotDataset.create(
            cfg.dataset.repo_id,
            cfg.dataset.fps,
            root=cfg.dataset.root,
            robot_type=robot.name,
            features=dataset_features,
            use_videos=cfg.dataset.video,
            image_writer_processes=cfg.dataset.num_image_writer_processes,
            image_writer_threads=image_writer_threads,
        )

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    robot.connect()
    if teleop is not None:
        teleop.connect()

    listener, events = init_keyboard_listener()

    # TODO: no warmup phase is run before the first episode yet, so `cfg.warmup_time_s`
    # is currently unused by this function.

    # Count only saved episodes: a re-recorded episode must not use up one of the
    # `num_episodes` slots. A pending stop request also prevents starting a new episode.
    recorded_episodes = 0
    while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
        log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
        record_loop(
            robot=robot,
            events=events,
            teleop=teleop,
            policy=policy,
            dataset=dataset,
            control_time_s=cfg.dataset.episode_time_s,
            fps=cfg.dataset.fps,
            single_task=cfg.dataset.single_task,
            display_data=cfg.display_data,
        )

        # Execute a few seconds without recording to give time to manually reset the environment
        # Skip reset for the last episode to be recorded
        if not events["stop_recording"] and (
            (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
        ):
            log_say("Reset the environment", cfg.play_sounds)
            record_loop(
                robot=robot,
                events=events,
                teleop=teleop,
                control_time_s=cfg.dataset.reset_time_s,
                fps=cfg.dataset.fps,
                single_task=cfg.dataset.single_task,
                display_data=cfg.display_data,
            )

        if events["rerecord_episode"]:
            log_say("Re-record episode", cfg.play_sounds)
            events["rerecord_episode"] = False
            events["exit_early"] = False
            dataset.clear_episode_buffer()
            continue

        dataset.save_episode()
        recorded_episodes += 1

    log_say("Stop recording", cfg.play_sounds, blocking=True)

    robot.disconnect()
    # The teleoperator is optional (policy-driven runs), mirror the guarded connect above.
    if teleop is not None:
        teleop.disconnect()

    if not is_headless() and listener is not None:
        listener.stop()

    if cfg.dataset.push_to_hub:
        dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset


if __name__ == "__main__":
    record()