# NOTE: forked from tangger/lerobot (348 lines, 13 KiB, Python).
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""
|
|
Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy.

Example:

```shell
python -m lerobot.record \
    --robot.type=so100_follower \
    --robot.port=/dev/tty.usbmodem58760431541 \
    --robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \
    --robot.id=black \
    --dataset.repo_id=aliberts/record-test \
    --dataset.num_episodes=2 \
    --dataset.single_task="Grab the cube" \
    # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
    # --teleop.type=so100_leader \
    # --teleop.port=/dev/tty.usbmodem58760431551 \
    # --teleop.id=blue \
    # <- Policy optional if you want to record with a policy \
    # --policy.path=${HF_USER}/my_policy \
```
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from pprint import pformat
|
|
|
|
import numpy as np
|
|
import rerun as rr
|
|
|
|
from lerobot.common.cameras import ( # noqa: F401
|
|
CameraConfig, # noqa: F401
|
|
)
|
|
from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
|
from lerobot.common.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
|
from lerobot.common.datasets.image_writer import safe_stop_image_writer
|
|
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
|
|
from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
|
from lerobot.common.policies.factory import make_policy
|
|
from lerobot.common.policies.pretrained import PreTrainedPolicy
|
|
from lerobot.common.robots import ( # noqa: F401
|
|
Robot,
|
|
RobotConfig,
|
|
koch_follower,
|
|
make_robot_from_config,
|
|
so100_follower,
|
|
so101_follower,
|
|
)
|
|
from lerobot.common.teleoperators import ( # noqa: F401
|
|
Teleoperator,
|
|
TeleoperatorConfig,
|
|
make_teleoperator_from_config,
|
|
)
|
|
from lerobot.common.utils.control_utils import (
|
|
init_keyboard_listener,
|
|
is_headless,
|
|
predict_action,
|
|
sanity_check_dataset_name,
|
|
sanity_check_dataset_robot_compatibility,
|
|
)
|
|
from lerobot.common.utils.robot_utils import busy_wait
|
|
from lerobot.common.utils.utils import (
|
|
get_safe_torch_device,
|
|
init_logging,
|
|
log_say,
|
|
)
|
|
from lerobot.common.utils.visualization_utils import _init_rerun
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
|
|
from .common.teleoperators import koch_leader, so100_leader, so101_leader # noqa: F401
|
|
|
|
|
|
@dataclass
|
|
class DatasetRecordConfig:
|
|
# Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).
|
|
repo_id: str
|
|
# A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.")
|
|
single_task: str
|
|
# Root directory where the dataset will be stored (e.g. 'dataset/path').
|
|
root: str | Path | None = None
|
|
# Limit the frames per second.
|
|
fps: int = 30
|
|
# Number of seconds for data recording for each episode.
|
|
episode_time_s: int | float = 60
|
|
# Number of seconds for resetting the environment after each episode.
|
|
reset_time_s: int | float = 60
|
|
# Number of episodes to record.
|
|
num_episodes: int = 50
|
|
# Encode frames in the dataset into video
|
|
video: bool = True
|
|
# Upload dataset to Hugging Face hub.
|
|
push_to_hub: bool = True
|
|
# Upload on private repository on the Hugging Face hub.
|
|
private: bool = False
|
|
# Add tags to your dataset on the hub.
|
|
tags: list[str] | None = None
|
|
# Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only;
|
|
# set to ≥1 to use subprocesses, each using threads to write images. The best number of processes
|
|
# and threads depends on your system. We recommend 4 threads per camera with 0 processes.
|
|
# If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses.
|
|
num_image_writer_processes: int = 0
|
|
# Number of threads writing the frames as png images on disk, per camera.
|
|
# Too many threads might cause unstable teleoperation fps due to main thread being blocked.
|
|
# Not enough threads might cause low camera fps.
|
|
num_image_writer_threads_per_camera: int = 4
|
|
|
|
def __post_init__(self):
|
|
if self.single_task is None:
|
|
raise ValueError("You need to provide a task as argument in `single_task`.")
|
|
|
|
|
|
@dataclass
class RecordConfig:
    """Top-level configuration of a recording run.

    At least one of `teleop` and `policy` must end up set, otherwise construction fails.

    Raises:
        ValueError: on construction, if neither a teleoperator nor a policy is configured.
    """

    robot: RobotConfig
    dataset: DatasetRecordConfig
    # Teleoperator used to control the robot, if any.
    teleop: TeleoperatorConfig | None = None
    # Policy used to control the robot, if any.
    policy: PreTrainedConfig | None = None
    # Display all cameras on screen.
    display_data: bool = False
    # Use vocal synthesis to read events.
    play_sounds: bool = True
    # Resume recording on an existing dataset.
    resume: bool = False

    def __post_init__(self):
        """Load the policy config from the CLI path argument (if any) and validate the controllers."""
        # HACK: the CLI args are parsed a second time here to recover the pretrained path, if given.
        pretrained_path = parser.get_path_arg("policy")
        if pretrained_path:
            overrides = parser.get_cli_overrides("policy")
            loaded_cfg = PreTrainedConfig.from_pretrained(pretrained_path, cli_overrides=overrides)
            self.policy = loaded_cfg
            loaded_cfg.pretrained_path = pretrained_path

        has_controller = self.teleop is not None or self.policy is not None
        if not has_controller:
            raise ValueError("Choose a policy, a teleoperator or both to control the robot")

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        path_fields = ["policy"]
        return path_fields
|
|
|
|
|
|
@safe_stop_image_writer
def record_loop(
    robot: Robot,
    events: dict,
    fps: int,
    dataset: LeRobotDataset | None = None,
    teleop: Teleoperator | None = None,
    policy: PreTrainedPolicy | None = None,
    control_time_s: int | float | None = None,
    single_task: str | None = None,
    display_data: bool = False,
):
    """Run one control phase at a fixed rate, optionally recording frames.

    Each iteration reads an observation from the robot, gets an action from the policy if one
    is given (otherwise from the teleoperator), sends it to the robot and, when a dataset is
    given, appends the observation together with the action returned by `robot.send_action`
    as one frame. When neither a policy nor a teleoperator is given, the loop only reads
    observations and waits until the phase is over.

    Args:
        robot: Robot to observe and command.
        events: Shared flag dictionary. `events["exit_early"]` ends the phase when truthy and
            is reset to False before returning.
        fps: Target loop rate; must equal `dataset.fps` when a dataset is given.
        dataset: Dataset receiving the frames, or None to control without recording.
        teleop: Teleoperator providing actions when no policy is given.
        policy: Policy providing actions; takes precedence over `teleop`. Requires `dataset`.
        control_time_s: Duration of the phase in seconds. None means no time limit, in which
            case the loop only ends through `events["exit_early"]`.
        single_task: Task description passed to the policy and stored with every frame.
        display_data: Log float observations/actions and image observations to rerun.

    Raises:
        ValueError: if `dataset.fps` differs from `fps`, or if a policy is given without a
            dataset.
    """
    if dataset is not None and dataset.fps != fps:
        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).")

    # The policy input frame is built from `dataset.features`, so a policy cannot run without a
    # dataset. Fail early with an explicit message instead of an AttributeError on None.
    if policy is not None and dataset is None:
        raise ValueError("A dataset is required to run a policy: its features define the policy input.")

    # `None` used to crash on the `timestamp < control_time_s` comparison; treat it as unbounded.
    if control_time_s is None:
        control_time_s = float("inf")

    # if policy is given it needs cleaning up
    if policy is not None:
        policy.reset()

    warned_no_controller = False
    timestamp = 0
    start_episode_t = time.perf_counter()
    while timestamp < control_time_s:
        start_loop_t = time.perf_counter()

        if events["exit_early"]:
            events["exit_early"] = False
            break

        observation = robot.get_observation()

        # A policy implies a dataset (checked above), so this also covers the policy case.
        if dataset is not None:
            observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")

        if policy is not None:
            action_values = predict_action(
                observation_frame,
                policy,
                get_safe_torch_device(policy.config.device),
                policy.config.use_amp,
                task=single_task,
                robot_type=robot.robot_type,
            )
            action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)}
        elif teleop is not None:
            action = teleop.get_action()
        else:
            # Log once per phase rather than on every iteration.
            if not warned_no_controller:
                logging.info(
                    "No policy or teleoperator provided, skipping action generation. "
                    "This is likely to happen when resetting the environment without a teleop device. "
                    "The robot won't be at its rest position at the start of the next episode."
                )
                warned_no_controller = True
            # Keep pacing the loop and advancing the clock: skipping these made the phase spin
            # forever, since `timestamp` never reached `control_time_s`.
            busy_wait(1 / fps - (time.perf_counter() - start_loop_t))
            timestamp = time.perf_counter() - start_episode_t
            continue

        # Action can eventually be clipped using `max_relative_target`,
        # so action actually sent is saved in the dataset.
        sent_action = robot.send_action(action)

        if dataset is not None:
            action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
            frame = {**observation_frame, **action_frame}
            dataset.add_frame(frame, task=single_task)

        if display_data:
            for obs, val in observation.items():
                if isinstance(val, float):
                    rr.log(f"observation.{obs}", rr.Scalar(val))
                elif isinstance(val, np.ndarray):
                    rr.log(f"observation.{obs}", rr.Image(val), static=True)
            for act, val in action.items():
                if isinstance(val, float):
                    rr.log(f"action.{act}", rr.Scalar(val))

        # Wait out whatever is left of this iteration's 1/fps time budget.
        dt_s = time.perf_counter() - start_loop_t
        busy_wait(1 / fps - dt_s)

        timestamp = time.perf_counter() - start_episode_t
|
|
|
|
|
|
@parser.wrap()
def record(cfg: RecordConfig) -> LeRobotDataset:
    """Record `cfg.dataset.num_episodes` episodes with the configured robot and return the dataset.

    The dataset is either created from scratch or, with `cfg.resume`, loaded from
    `cfg.dataset.repo_id` / `cfg.dataset.root`. Each episode is one recording `record_loop`
    phase followed, except after the last episode, by a non-recording reset phase. An episode
    flagged through `events["rerecord_episode"]` is discarded and recorded again without
    counting toward the requested number of episodes. `events["stop_recording"]` ends the
    session after the current episode is handled.

    The robot, the teleoperator and the keyboard listener are released even when recording
    raises. The dataset is pushed to the hub only after a successful session and only when
    `cfg.dataset.push_to_hub` is set.

    Args:
        cfg: Recording configuration (robot, dataset, optional teleoperator and policy).

    Returns:
        The dataset the episodes were saved into.
    """
    init_logging()
    logging.info(pformat(asdict(cfg)))
    if cfg.display_data:
        _init_rerun(session_name="recording")

    robot = make_robot_from_config(cfg.robot)
    teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None

    action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
    obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
    dataset_features = {**action_features, **obs_features}

    # A robot without a `cameras` attribute counts as having no camera in both branches below;
    # previously only the resume branch guarded against the missing attribute.
    num_cameras = len(robot.cameras) if hasattr(robot, "cameras") else 0
    num_image_writer_threads = cfg.dataset.num_image_writer_threads_per_camera * num_cameras

    if cfg.resume:
        dataset = LeRobotDataset(
            cfg.dataset.repo_id,
            root=cfg.dataset.root,
        )

        if num_cameras > 0:
            dataset.start_image_writer(
                num_processes=cfg.dataset.num_image_writer_processes,
                num_threads=num_image_writer_threads,
            )
        sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features)
    else:
        # Create empty dataset
        sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy)
        dataset = LeRobotDataset.create(
            cfg.dataset.repo_id,
            cfg.dataset.fps,
            root=cfg.dataset.root,
            robot_type=robot.name,
            features=dataset_features,
            use_videos=cfg.dataset.video,
            image_writer_processes=cfg.dataset.num_image_writer_processes,
            image_writer_threads=num_image_writer_threads,
        )

    # Load pretrained policy
    policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)

    robot.connect()
    teleop_connected = False
    listener = None
    try:
        if teleop is not None:
            teleop.connect()
            teleop_connected = True

        listener, events = init_keyboard_listener()

        # Count saved episodes only, so a re-recorded take does not use up one of the
        # requested episodes.
        recorded_episodes = 0
        while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
            log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
            record_loop(
                robot=robot,
                events=events,
                fps=cfg.dataset.fps,
                teleop=teleop,
                policy=policy,
                dataset=dataset,
                control_time_s=cfg.dataset.episode_time_s,
                single_task=cfg.dataset.single_task,
                display_data=cfg.display_data,
            )

            # Execute a few seconds without recording to give time to manually reset the environment
            # Skip reset for the last episode to be recorded
            if not events["stop_recording"] and (
                (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"]
            ):
                log_say("Reset the environment", cfg.play_sounds)
                record_loop(
                    robot=robot,
                    events=events,
                    fps=cfg.dataset.fps,
                    teleop=teleop,
                    control_time_s=cfg.dataset.reset_time_s,
                    single_task=cfg.dataset.single_task,
                    display_data=cfg.display_data,
                )

            if events["rerecord_episode"]:
                log_say("Re-record episode", cfg.play_sounds)
                events["rerecord_episode"] = False
                events["exit_early"] = False
                # Drop the buffered frames of the discarded take.
                dataset.clear_episode_buffer()
                continue

            dataset.save_episode()
            recorded_episodes += 1

        log_say("Stop recording", cfg.play_sounds, blocking=True)
    finally:
        # Release the hardware and the listener whether or not recording succeeded.
        robot.disconnect()
        if teleop_connected:
            teleop.disconnect()

        if not is_headless() and listener is not None:
            listener.stop()

    if cfg.dataset.push_to_hub:
        dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private)

    log_say("Exiting", cfg.play_sounds)
    return dataset
|
|
|
|
|
|
# Script entry point: `record` is wrapped by `parser.wrap()`, so it is called without arguments.
if __name__ == "__main__":
    record()
|