Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions

lerobot/scripts/control_robot.py

@@ -8,30 +8,42 @@ Examples of usage:
- Recalibrate your robot:
```bash
python lerobot/scripts/control_robot.py calibrate
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=calibrate
```
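- Calibrate only specific arms (a sketch: `arms` is a field of `CalibrateControlConfig` used below; arm IDs such as `main_follower` depend on your robot, and the list syntax follows draccus conventions):
```bash
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=calibrate \
--control.arms='["main_follower"]'
```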
- Unlimited teleoperation at the highest frequency (~200 Hz is expected); exit with CTRL+C:
```bash
python lerobot/scripts/control_robot.py teleoperate
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--robot.cameras='{}' \
--control.type=teleoperate
# Remove the cameras from the robot definition. They are not used in 'teleoperate' anyway.
python lerobot/scripts/control_robot.py teleoperate --robot-overrides '~cameras'
# Add the cameras from the robot definition to visualize them:
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=teleoperate
```
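- Teleoperate without opening camera windows (a sketch: `display_cameras` is a field of `TeleoperateControlConfig` used below):
```bash
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=teleoperate \
--control.display_cameras=false
```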
- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate the data recording frequency:
```bash
python lerobot/scripts/control_robot.py teleoperate \
--fps 30
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=teleoperate \
--control.fps=30
```
- Record one episode in order to test replay:
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--repo-id $USER/koch_test \
--num-episodes 1 \
--run-compute-stats 0
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=$USER/koch_test \
--control.num_episodes=1 \
--control.push_to_hub=True
```
- Visualize dataset:
@@ -44,21 +56,25 @@ python lerobot/scripts/visualize_dataset.py \
- Replay this test episode:
```bash
python lerobot/scripts/control_robot.py replay \
--fps 30 \
--repo-id $USER/koch_test \
--episode 0
--robot.type=so100 \
--control.type=replay \
--control.fps=30 \
--control.repo_id=$USER/koch_test \
--control.episode=0
```
- Record a full dataset in order to train a policy, with 2 seconds of warmup,
30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes:
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--repo-id $USER/koch_pick_place_lego \
--num-episodes 50 \
--warmup-time-s 2 \
--episode-time-s 30 \
--reset-time-s 10
--robot.type=so100 \
--control.type=record \
--control.fps=30 \
--control.repo_id=$USER/koch_pick_place_lego \
--control.num_episodes=50 \
--control.warmup_time_s=2 \
--control.episode_time_s=30 \
--control.reset_time_s=10
```
**NOTE**: You can use your keyboard to control the data recording flow.
@@ -68,44 +84,55 @@ python lerobot/scripts/control_robot.py record \
- Tap escape key 'esc' to stop the data recording.
This might require sudo permissions to allow your terminal to monitor keyboard events.
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--resume 1`.
If the dataset you want to extend is not on the hub, you also need to add `--local-files-only 1`.
**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`.
If the dataset you want to extend is not on the hub, you also need to add `--control.local_files_only=true`.
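For example, a sketch that appends one more episode to the `$USER/koch_test` dataset recorded above, composed only of flags shown in this file:
```bash
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=$USER/koch_test \
--control.num_episodes=1 \
--control.resume=true
```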
- Train on this dataset with the ACT policy:
```bash
python lerobot/scripts/train.py \
policy=act_koch_real \
env=koch_real \
dataset_repo_id=$USER/koch_pick_place_lego \
hydra.run.dir=outputs/train/act_koch_real
--dataset.repo_id=${HF_USER}/koch_pick_place_lego \
--policy.type=act \
--output_dir=outputs/train/act_koch_pick_place_lego \
--job_name=act_koch_pick_place_lego \
--device=cuda \
--wandb.enable=true
```
- Run the pretrained policy on the robot:
```bash
python lerobot/scripts/control_robot.py record \
--fps 30 \
--repo-id $USER/eval_act_koch_real \
--num-episodes 10 \
--warmup-time-s 2 \
--episode-time-s 30 \
--reset-time-s 10
-p outputs/train/act_koch_real/checkpoints/080000/pretrained_model
python lerobot/scripts/control_robot.py \
--robot.type=so100 \
--control.type=record \
--control.fps=30 \
--control.single_task="Grasp a lego block and put it in the bin." \
--control.repo_id=$USER/eval_act_koch_pick_place_lego \
--control.num_episodes=10 \
--control.warmup_time_s=2 \
--control.episode_time_s=30 \
--control.reset_time_s=10 \
--control.push_to_hub=true \
--control.policy.path=outputs/train/act_koch_pick_place_lego/checkpoints/080000/pretrained_model
```
"""
import argparse
import logging
import time
from pathlib import Path
from typing import List
from dataclasses import asdict
from pprint import pformat
# from safetensors.torch import load_file, save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.policies.factory import make_policy
from lerobot.common.robot_devices.control_configs import (
CalibrateControlConfig,
ControlPipelineConfig,
RecordControlConfig,
ReplayControlConfig,
TeleoperateControlConfig,
)
from lerobot.common.robot_devices.control_utils import (
control_loop,
has_method,
init_keyboard_listener,
init_policy,
log_control_info,
record_episode,
reset_environment,
@@ -114,10 +141,10 @@ from lerobot.common.robot_devices.control_utils import (
stop_recording,
warmup_record,
)
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config
from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int
from lerobot.common.utils.utils import has_method, init_logging, log_say
from lerobot.configs import parser
########################################################################################
# Control modes
@@ -125,7 +152,7 @@ from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int
@safe_disconnect
def calibrate(robot: Robot, arms: list[str] | None):
def calibrate(robot: Robot, cfg: CalibrateControlConfig):
# TODO(aliberts): move this code in robots' classes
if robot.robot_type.startswith("stretch"):
if not robot.is_connected:
@@ -134,9 +161,7 @@ def calibrate(robot: Robot, arms: list[str] | None):
robot.home()
return
if arms is None:
arms = robot.available_arms
arms = robot.available_arms if cfg.arms is None else cfg.arms
unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms]
available_arms_str = " ".join(robot.available_arms)
unknown_arms_str = " ".join(unknown_arms)
@@ -171,91 +196,50 @@ def calibrate(robot: Robot, arms: list[str] | None):
@safe_disconnect
def teleoperate(
robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False
):
def teleoperate(robot: Robot, cfg: TeleoperateControlConfig):
control_loop(
robot,
control_time_s=teleop_time_s,
fps=fps,
control_time_s=cfg.teleop_time_s,
fps=cfg.fps,
teleoperate=True,
display_cameras=display_cameras,
display_cameras=cfg.display_cameras,
)
@safe_disconnect
def record(
robot: Robot,
root: Path,
repo_id: str,
single_task: str,
pretrained_policy_name_or_path: str | None = None,
policy_overrides: List[str] | None = None,
fps: int | None = None,
warmup_time_s: int | float = 2,
episode_time_s: int | float = 10,
reset_time_s: int | float = 5,
num_episodes: int = 50,
video: bool = True,
run_compute_stats: bool = True,
push_to_hub: bool = True,
tags: list[str] | None = None,
num_image_writer_processes: int = 0,
num_image_writer_threads_per_camera: int = 4,
display_cameras: bool = True,
play_sounds: bool = True,
resume: bool = False,
# TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument
local_files_only: bool = False,
cfg: RecordControlConfig,
) -> LeRobotDataset:
# TODO(rcadene): Add option to record logs
listener = None
events = None
policy = None
device = None
use_amp = None
if single_task:
task = single_task
else:
raise NotImplementedError("Only single-task recording is supported for now")
# Load pretrained policy
if pretrained_policy_name_or_path is not None:
policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides)
if fps is None:
fps = policy_fps
logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).")
elif fps != policy_fps:
logging.warning(
f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})."
)
if resume:
if cfg.resume:
dataset = LeRobotDataset(
repo_id,
root=root,
local_files_only=local_files_only,
cfg.repo_id,
root=cfg.root,
local_files_only=cfg.local_files_only,
)
dataset.start_image_writer(
num_processes=num_image_writer_processes,
num_threads=num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, fps, video)
if len(robot.cameras) > 0:
dataset.start_image_writer(
num_processes=cfg.num_image_writer_processes,
num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video)
else:
# Create empty dataset or load existing saved episodes
sanity_check_dataset_name(repo_id, policy)
sanity_check_dataset_name(cfg.repo_id, cfg.policy)
dataset = LeRobotDataset.create(
repo_id,
fps,
root=root,
cfg.repo_id,
cfg.fps,
root=cfg.root,
robot=robot,
use_videos=video,
image_writer_processes=num_image_writer_processes,
image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras),
use_videos=cfg.video,
image_writer_processes=cfg.num_image_writer_processes,
image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras),
)
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, cfg.device, ds_meta=dataset.meta)
if not robot.is_connected:
robot.connect()
@@ -266,33 +250,28 @@ def record(
# 2. give times to the robot devices to connect and start synchronizing,
# 3. place the cameras windows on screen
enable_teleoperation = policy is None
log_say("Warmup record", play_sounds)
warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps)
log_say("Warmup record", cfg.play_sounds)
warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_cameras, cfg.fps)
if has_method(robot, "teleop_safety_stop"):
robot.teleop_safety_stop()
recorded_episodes = 0
while True:
if recorded_episodes >= num_episodes:
if recorded_episodes >= cfg.num_episodes:
break
# TODO(aliberts): add task prompt for multitask here. Might need to temporarily disable event if
# input() messes with them.
# if multi_task:
# task = input("Enter your task description: ")
log_say(f"Recording episode {dataset.num_episodes}", play_sounds)
log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds)
record_episode(
dataset=dataset,
robot=robot,
events=events,
episode_time_s=episode_time_s,
display_cameras=display_cameras,
episode_time_s=cfg.episode_time_s,
display_cameras=cfg.display_cameras,
policy=policy,
device=device,
use_amp=use_amp,
fps=fps,
device=cfg.device,
use_amp=cfg.use_amp,
fps=cfg.fps,
)
# Execute a few seconds without recording to give time to manually reset the environment
@@ -300,59 +279,56 @@ def record(
# TODO(rcadene): add an option to enable teleoperation during reset
# Skip reset for the last episode to be recorded
if not events["stop_recording"] and (
(recorded_episodes < num_episodes - 1) or events["rerecord_episode"]
(recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"]
):
log_say("Reset the environment", play_sounds)
reset_environment(robot, events, reset_time_s)
log_say("Reset the environment", cfg.play_sounds)
reset_environment(robot, events, cfg.reset_time_s)
if events["rerecord_episode"]:
log_say("Re-record episode", play_sounds)
log_say("Re-record episode", cfg.play_sounds)
events["rerecord_episode"] = False
events["exit_early"] = False
dataset.clear_episode_buffer()
continue
dataset.save_episode(task)
dataset.save_episode(cfg.single_task)
recorded_episodes += 1
if events["stop_recording"]:
break
log_say("Stop recording", play_sounds, blocking=True)
stop_recording(robot, listener, display_cameras)
log_say("Stop recording", cfg.play_sounds, blocking=True)
stop_recording(robot, listener, cfg.display_cameras)
if run_compute_stats:
if cfg.run_compute_stats:
logging.info("Computing dataset statistics")
dataset.consolidate(run_compute_stats)
dataset.consolidate(cfg.run_compute_stats)
if push_to_hub:
dataset.push_to_hub(tags=tags)
if cfg.push_to_hub:
dataset.push_to_hub(tags=cfg.tags, private=cfg.private)
log_say("Exiting", play_sounds)
log_say("Exiting", cfg.play_sounds)
return dataset
@safe_disconnect
def replay(
robot: Robot,
root: Path,
repo_id: str,
episode: int,
fps: int | None = None,
play_sounds: bool = True,
local_files_only: bool = False,
cfg: ReplayControlConfig,
):
# TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset
# TODO(rcadene): Add option to record logs
dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only)
dataset = LeRobotDataset(
cfg.repo_id, root=cfg.root, episodes=[cfg.episode], local_files_only=cfg.local_files_only
)
actions = dataset.hf_dataset.select_columns("action")
if not robot.is_connected:
robot.connect()
log_say("Replaying episode", play_sounds, blocking=True)
log_say("Replaying episode", cfg.play_sounds, blocking=True)
for idx in range(dataset.num_frames):
start_episode_t = time.perf_counter()
@@ -360,216 +336,33 @@ def replay(
robot.send_action(action)
dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / fps - dt_s)
busy_wait(1 / cfg.fps - dt_s)
dt_s = time.perf_counter() - start_episode_t
log_control_info(robot, dt_s, fps=fps)
log_control_info(robot, dt_s, fps=cfg.fps)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
subparsers = parser.add_subparsers(dest="mode", required=True)
# Set common options for all the subparsers
base_parser = argparse.ArgumentParser(add_help=False)
base_parser.add_argument(
"--robot-path",
type=str,
default="lerobot/configs/robot/koch.yaml",
help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
)
base_parser.add_argument(
"--robot-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser_calib = subparsers.add_parser("calibrate", parents=[base_parser])
parser_calib.add_argument(
"--arms",
type=str,
nargs="*",
help="List of arms to calibrate (e.g. `--arms left_follower right_follower left_leader`)",
)
parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser])
parser_teleop.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_teleop.add_argument(
"--display-cameras",
type=int,
default=1,
help="Display all cameras on screen (set to 1 to display or 0).",
)
parser_record = subparsers.add_parser("record", parents=[base_parser])
task_args = parser_record.add_mutually_exclusive_group(required=True)
parser_record.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
task_args.add_argument(
"--single-task",
type=str,
help="A short but accurate description of the task performed during the recording.",
)
# TODO(aliberts): add multi-task support
# task_args.add_argument(
# "--multi-task",
# type=int,
# help="You will need to enter the task performed at the start of each episode.",
# )
parser_record.add_argument(
"--root",
type=Path,
default=None,
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
)
parser_record.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_record.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser_record.add_argument(
"--warmup-time-s",
type=int,
default=10,
help="Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize.",
)
parser_record.add_argument(
"--episode-time-s",
type=int,
default=60,
help="Number of seconds for data recording for each episode.",
)
parser_record.add_argument(
"--reset-time-s",
type=int,
default=60,
help="Number of seconds for resetting the environment after each episode.",
)
parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.")
parser_record.add_argument(
"--run-compute-stats",
type=int,
default=1,
help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.",
)
parser_record.add_argument(
"--push-to-hub",
type=int,
default=1,
help="Upload dataset to Hugging Face hub.",
)
parser_record.add_argument(
"--tags",
type=str,
nargs="*",
help="Add tags to your dataset on the hub.",
)
parser_record.add_argument(
"--num-image-writer-processes",
type=int,
default=0,
help=(
"Number of subprocesses handling the saving of frames as PNGs. Set to 0 to use threads only; "
"set to ≥1 to use subprocesses, each using threads to write images. The best number of processes "
"and threads depends on your system. We recommend 4 threads per camera with 0 processes. "
"If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses."
),
)
parser_record.add_argument(
"--num-image-writer-threads-per-camera",
type=int,
default=4,
help=(
"Number of threads writing the frames as png images on disk, per camera. "
"Too many threads might cause unstable teleoperation fps due to main thread being blocked. "
"Not enough threads might cause low camera fps."
),
)
parser_record.add_argument(
"--resume",
type=int,
default=0,
help="Resume recording on an existing dataset.",
)
parser_record.add_argument(
"-p",
"--pretrained-policy-name-or-path",
type=str,
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`."
),
)
parser_record.add_argument(
"--policy-overrides",
type=str,
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
parser_replay = subparsers.add_parser("replay", parents=[base_parser])
parser_replay.add_argument(
"--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)"
)
parser_replay.add_argument(
"--root",
type=Path,
default=None,
help="Root directory where the dataset will be stored (e.g. 'dataset/path').",
)
parser_replay.add_argument(
"--repo-id",
type=str,
default="lerobot/test",
help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).",
)
parser_replay.add_argument(
"--local-files-only",
type=int,
default=0,
help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.",
)
parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.")
args = parser.parse_args()
@parser.wrap()
def control_robot(cfg: ControlPipelineConfig):
init_logging()
logging.info(pformat(asdict(cfg)))
control_mode = args.mode
robot_path = args.robot_path
robot_overrides = args.robot_overrides
kwargs = vars(args)
del kwargs["mode"]
del kwargs["robot_path"]
del kwargs["robot_overrides"]
robot = make_robot_from_config(cfg.robot)
robot_cfg = init_hydra_config(robot_path, robot_overrides)
robot = make_robot(robot_cfg)
if control_mode == "calibrate":
calibrate(robot, **kwargs)
elif control_mode == "teleoperate":
teleoperate(robot, **kwargs)
elif control_mode == "record":
record(robot, **kwargs)
elif control_mode == "replay":
replay(robot, **kwargs)
if isinstance(cfg.control, CalibrateControlConfig):
calibrate(robot, cfg.control)
elif isinstance(cfg.control, TeleoperateControlConfig):
teleoperate(robot, cfg.control)
elif isinstance(cfg.control, RecordControlConfig):
record(robot, cfg.control)
elif isinstance(cfg.control, ReplayControlConfig):
replay(robot, cfg.control)
if robot.is_connected:
# Disconnect manually to avoid a "Core dump" during process
# termination due to camera threads not properly exiting.
robot.disconnect()
if __name__ == "__main__":
control_robot()

lerobot/scripts/control_sim_robot.py

@@ -90,11 +90,12 @@ from lerobot.common.robot_devices.control_utils import (
sanity_check_dataset_robot_compatibility,
stop_recording,
)
from lerobot.common.robot_devices.robots.factory import make_robot
from lerobot.common.robot_devices.robots.utils import Robot
from lerobot.common.robot_devices.robots.utils import Robot, make_robot
from lerobot.common.robot_devices.utils import busy_wait
from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say
raise NotImplementedError("This script is currently deactivated")
DEFAULT_FEATURES = {
"next.reward": {
"dtype": "float32",
@@ -227,7 +228,7 @@ def record(
shape = env.observation_space[key].shape
if not key.startswith("observation.image."):
key = "observation.image." + key
features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape}
features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape}
for key, obs_key in state_keys_dict.items():
features[key] = {
@@ -504,7 +505,7 @@ if __name__ == "__main__":
# make gym env
env_cfg = init_hydra_config(env_config_path)
importlib.import_module(f"gym_{env_cfg.env.name}")
importlib.import_module(f"gym_{env_cfg.env.type}")
def env_constructor():
return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym)
@@ -515,6 +516,7 @@ if __name__ == "__main__":
if control_mode in ["teleoperate", "record"]:
# make robot
robot_overrides = ["~cameras", "~follower_arms"]
# TODO(rcadene): remove
robot_cfg = init_hydra_config(robot_path, robot_overrides)
robot = make_robot(robot_cfg)
robot.connect()

lerobot/scripts/eval.py

@@ -21,67 +21,69 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht)
for 10 episodes.
```
python lerobot/scripts/eval.py -p lerobot/diffusion_pusht eval.n_episodes=10
python lerobot/scripts/eval.py \
--policy.path=lerobot/diffusion_pusht \
--env.type=pusht \
--eval.batch_size=10 \
--eval.n_episodes=10 \
--use_amp=false \
--device=cuda
```
OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
```
python lerobot/scripts/eval.py \
-p outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
eval.n_episodes=10
--policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
--env.type=pusht \
--eval.batch_size=10 \
--eval.n_episodes=10 \
--use_amp=false \
--device=cuda
```
Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and
`model.safetensors`.
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments
with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes`
nested under `eval` in the `config.yaml` found at
https://huggingface.co/lerobot/diffusion_pusht/tree/main.
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
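For instance, asynchronous vectorized environments can be enabled from the CLI (a sketch: `eval.use_async_envs` is forwarded to `make_env` below):
```
python lerobot/scripts/eval.py \
--policy.path=lerobot/diffusion_pusht \
--env.type=pusht \
--eval.batch_size=10 \
--eval.n_episodes=10 \
--eval.use_async_envs=true \
--device=cuda
```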
"""
import argparse
import json
import logging
import threading
import time
from contextlib import nullcontext
from copy import deepcopy
from datetime import datetime as dt
from dataclasses import asdict
from pathlib import Path
from pprint import pformat
from typing import Callable
import einops
import gymnasium as gym
import numpy as np
import torch
from huggingface_hub import snapshot_download
from huggingface_hub.errors import RepositoryNotFoundError
from huggingface_hub.utils._validators import HFValidationError
from torch import Tensor, nn
from tqdm import trange
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.policies.pretrained import PreTrainedPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.io_utils import write_video
from lerobot.common.utils.utils import (
get_safe_torch_device,
init_hydra_config,
init_logging,
inside_slurm,
set_global_seed,
)
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
def rollout(
env: gym.vector.VectorEnv,
policy: Policy,
policy: PreTrainedPolicy,
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
@@ -208,7 +210,7 @@ def rollout(
def eval_policy(
env: gym.vector.VectorEnv,
policy: torch.nn.Module,
policy: PreTrainedPolicy,
n_episodes: int,
max_episodes_rendered: int = 0,
videos_dir: Path | None = None,
@@ -232,7 +234,9 @@ def eval_policy(
if max_episodes_rendered > 0 and not videos_dir:
raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
assert isinstance(policy, Policy)
if not isinstance(policy, PreTrainedPolicy):
raise ValueError(policy)
start = time.time()
policy.eval()
@@ -442,66 +446,43 @@ def _compile_episode_data(
return data_dict
def main(
pretrained_policy_path: Path | None = None,
hydra_cfg_path: str | None = None,
out_dir: str | None = None,
config_overrides: list[str] | None = None,
):
assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
if pretrained_policy_path is not None:
hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
else:
hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
if hydra_cfg.eval.batch_size > hydra_cfg.eval.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
f"({hydra_cfg.eval.batch_size} > {hydra_cfg.eval.n_episodes}). As a result, {hydra_cfg.eval.batch_size} "
f"eval environments will be instantiated, but only {hydra_cfg.eval.n_episodes} will be used. "
"This might significantly slow down evaluation. To fix this, you should update your command "
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={hydra_cfg.eval.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={hydra_cfg.eval.n_episodes}`)."
)
if out_dir is None:
out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
@parser.wrap()
def eval(cfg: EvalPipelineConfig):
logging.info(pformat(asdict(cfg)))
# Check device is available
device = get_safe_torch_device(hydra_cfg.device, log=True)
device = get_safe_torch_device(cfg.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
set_global_seed(hydra_cfg.seed)
set_global_seed(cfg.seed)
log_output_dir(out_dir)
log_output_dir(cfg.output_dir)
logging.info("Making environment.")
env = make_env(hydra_cfg)
env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
logging.info("Making policy.")
if hydra_cfg_path is None:
policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
else:
# Note: We need the dataset stats to pass to the policy's normalization modules.
policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
assert isinstance(policy, nn.Module)
policy = make_policy(
cfg=cfg.policy,
device=device,
env_cfg=cfg.env,
)
policy.eval()
with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
info = eval_policy(
env,
policy,
hydra_cfg.eval.n_episodes,
cfg.eval.n_episodes,
max_episodes_rendered=10,
videos_dir=Path(out_dir) / "videos",
start_seed=hydra_cfg.seed,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
)
print(info["aggregated"])
# Save info
with open(Path(out_dir) / "eval_info.json", "w") as f:
with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
json.dump(info, f, indent=2)
env.close()
@@ -509,76 +490,6 @@ def main(
logging.info("End of eval")
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
group.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--out-dir",
help=(
"Where to save the evaluation outputs. If not provided, outputs are saved in "
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
),
)
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
if args.pretrained_policy_name_or_path is None:
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
else:
pretrained_policy_path = get_pretrained_policy_path(
args.pretrained_policy_name_or_path, revision=args.revision
)
main(
pretrained_policy_path=pretrained_policy_path,
out_dir=args.out_dir,
config_overrides=args.overrides,
)
eval()

lerobot/scripts/push_pretrained.py

@@ -0,0 +1,71 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Once you have trained a policy with our training script (lerobot/scripts/train.py), use this script to push it
to the hub.
Example:
```bash
python lerobot/scripts/push_pretrained.py \
--pretrained_path=outputs/train/act_aloha_sim_transfer_cube_human/checkpoints/last/pretrained_model \
--repo_id=lerobot/act_aloha_sim_transfer_cube_human
```
"""
from dataclasses import dataclass
from pathlib import Path
import draccus
from huggingface_hub import HfApi
@dataclass
class PushPreTrainedConfig:
pretrained_path: Path
repo_id: str
branch: str | None = None
private: bool = False
exist_ok: bool = False
@draccus.wrap()
def main(cfg: PushPreTrainedConfig):
hub_api = HfApi()
hub_api.create_repo(
repo_id=cfg.repo_id,
private=cfg.private,
repo_type="model",
exist_ok=cfg.exist_ok,
)
if cfg.branch:
hub_api.create_branch(
repo_id=cfg.repo_id,
branch=cfg.branch,
repo_type="model",
exist_ok=cfg.exist_ok,
)
hub_api.upload_folder(
repo_id=cfg.repo_id,
folder_path=cfg.pretrained_path,
repo_type="model",
revision=cfg.branch,
)
if __name__ == "__main__":
main()
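The optional fields of `PushPreTrainedConfig` compose the same way (a sketch; the branch name is illustrative):
```bash
python lerobot/scripts/push_pretrained.py \
--pretrained_path=outputs/train/act_aloha_sim_transfer_cube_human/checkpoints/last/pretrained_model \
--repo_id=lerobot/act_aloha_sim_transfer_cube_human \
--branch=experiment \
--private=true
```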

lerobot/scripts/train.py

@@ -18,92 +18,36 @@ import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from copy import deepcopy
from pathlib import Path
from dataclasses import asdict
from pprint import pformat
from threading import Lock
import hydra
import numpy as np
import torch
from deepdiff import DeepDiff
from omegaconf import DictConfig, ListConfig, OmegaConf
from termcolor import colored
from torch import nn
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from lerobot.common.datasets.factory import make_dataset, resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import MultiLeRobotDataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.optim.factory import load_training_state, make_optimizer_and_scheduler
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import PolicyWithUpdate
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
has_method,
init_logging,
set_global_seed,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.scripts.eval import eval_policy
def make_optimizer_and_scheduler(cfg, policy):
if cfg.policy.name == "act":
optimizer_params_dicts = [
{
"params": [
p
for n, p in policy.named_parameters()
if not n.startswith("model.backbone") and p.requires_grad
]
},
{
"params": [
p
for n, p in policy.named_parameters()
if n.startswith("model.backbone") and p.requires_grad
],
"lr": cfg.training.lr_backbone,
},
]
optimizer = torch.optim.AdamW(
optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
)
lr_scheduler = None
elif cfg.policy.name == "diffusion":
optimizer = torch.optim.Adam(
policy.diffusion.parameters(),
cfg.training.lr,
cfg.training.adam_betas,
cfg.training.adam_eps,
cfg.training.adam_weight_decay,
)
from diffusers.optimization import get_scheduler
lr_scheduler = get_scheduler(
cfg.training.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.training.lr_warmup_steps,
num_training_steps=cfg.training.offline_steps,
)
elif policy.name == "tdmpc":
optimizer = torch.optim.Adam(policy.parameters(), cfg.training.lr)
lr_scheduler = None
elif cfg.policy.name == "vqbet":
from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
optimizer = VQBeTOptimizer(policy, cfg)
lr_scheduler = VQBeTScheduler(optimizer, cfg)
else:
raise NotImplementedError()
return optimizer, lr_scheduler
def update_policy(
policy,
batch,
@@ -145,7 +89,7 @@ def update_policy(
if lr_scheduler is not None:
lr_scheduler.step()
if isinstance(policy, PolicyWithUpdate):
if has_method(policy, "update"):
# To possibly update an internal buffer (for instance an Exponential Moving Average like in TDMPC).
policy.update()
@@ -161,7 +105,9 @@ def update_policy(
return info
def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
def log_train_info(
logger: Logger, info: dict, step: int, cfg: TrainPipelineConfig, dataset: LeRobotDataset, is_online: bool
):
loss = info["loss"]
grad_norm = info["grad_norm"]
lr = info["lr"]
@@ -170,7 +116,7 @@ def log_train_info(logger: Logger, info, step, cfg, dataset, is_online):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
@@ -207,7 +153,7 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
# A sample is an (observation,action) pair, where observation and action
# can be on multiple timestamps. In a batch, we have `batch_size` number of samples.
num_samples = (step + 1) * cfg.training.batch_size
num_samples = (step + 1) * cfg.batch_size
avg_samples_per_ep = dataset.num_frames / dataset.num_episodes
num_episodes = num_samples / avg_samples_per_ep
num_epochs = num_samples / dataset.num_frames
@@ -234,74 +180,17 @@ def log_eval_info(logger, info, step, cfg, dataset, is_online):
logger.log_dict(info, step, mode="eval")
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
if job_name is None:
raise NotImplementedError()
@parser.wrap()
def train(cfg: TrainPipelineConfig):
cfg.validate()
init_logging()
logging.info(pformat(OmegaConf.to_container(cfg)))
if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
# If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
# to check for any differences between the provided config and the checkpoint's config.
if cfg.resume:
if not Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
"You have set resume=True, but there is no model checkpoint in "
f"{Logger.get_last_checkpoint_dir(out_dir)}"
)
checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
logging.info(
colored(
"You have set resume=True, indicating that you wish to resume a run",
color="yellow",
attrs=["bold"],
)
)
# Get the configuration file from the last checkpoint.
checkpoint_cfg = init_hydra_config(checkpoint_cfg_path)
# Check for differences between the checkpoint configuration and provided configuration.
# Hack to resolve the delta_timestamps ahead of time in order to properly diff.
resolve_delta_timestamps(cfg)
diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
# Ignore the `resume` parameter.
if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
del diff["values_changed"]["root['resume']"]
# Log a warning about differences between the checkpoint configuration and the provided
# configuration.
if len(diff) > 0:
logging.warning(
"At least one difference was detected between the checkpoint configuration and "
f"the provided configuration: \n{pformat(diff)}\nNote that the checkpoint configuration "
"takes precedence.",
)
# Use the checkpoint config instead of the provided config (but keep `resume` parameter).
cfg = checkpoint_cfg
cfg.resume = True
elif Logger.get_last_checkpoint_dir(out_dir).exists():
raise RuntimeError(
f"The configured output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. If "
"you meant to resume training, please use `resume=true` in your command or yaml configuration."
)
if cfg.eval.batch_size > cfg.eval.n_episodes:
raise ValueError(
"The eval batch size is greater than the number of eval episodes "
f"({cfg.eval.batch_size} > {cfg.eval.n_episodes}). As a result, {cfg.eval.batch_size} "
f"eval environments will be instantiated, but only {cfg.eval.n_episodes} will be used. "
"This might significantly slow down evaluation. To fix this, you should update your command "
f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={cfg.eval.batch_size}`), "
f"or lower the batch size (e.g. `eval.batch_size={cfg.eval.n_episodes}`)."
)
logging.info(pformat(asdict(cfg)))
# log metrics to terminal and wandb
logger = Logger(cfg, out_dir, wandb_job_name=job_name)
logger = Logger(cfg)
set_global_seed(cfg.seed)
if cfg.seed is not None:
set_global_seed(cfg.seed)
# Check device is available
device = get_safe_torch_device(cfg.device, log=True)
@@ -309,65 +198,58 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_dataset")
logging.info("Creating dataset")
offline_dataset = make_dataset(cfg)
if isinstance(offline_dataset, MultiLeRobotDataset):
logging.info(
"Multiple datasets were provided. Applied the following index mapping to the provided datasets: "
f"{pformat(offline_dataset.repo_id_to_index , indent=2)}"
)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None
if cfg.training.eval_freq > 0:
logging.info("make_env")
eval_env = make_env(cfg)
if cfg.eval_freq > 0 and cfg.env is not None:
logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size)
logging.info("make_policy")
logging.info("Creating policy")
policy = make_policy(
hydra_cfg=cfg,
dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
cfg=cfg.policy,
device=device,
ds_meta=offline_dataset.meta,
)
assert isinstance(policy, nn.Module)
# Create optimizer and scheduler
# Temporary hack to move optimizer out of policy
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
grad_scaler = GradScaler(enabled=cfg.use_amp)
grad_scaler = GradScaler(device, enabled=cfg.use_amp)
step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume:
step = logger.load_last_training_state(optimizer, lr_scheduler)
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})")
logging.info(f"{cfg.training.online_steps=}")
log_output_dir(cfg.output_dir)
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info(f"{cfg.offline.steps=} ({format_big_number(cfg.offline.steps)})")
logging.info(f"{cfg.online.steps=}")
logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})")
logging.info(f"{offline_dataset.num_episodes=}")
logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
# Note: this helper will be used in offline and online training loops.
def evaluate_and_checkpoint_if_needed(step, is_online):
_num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps)))
def evaluate_and_checkpoint_if_needed(step: int, is_online: bool):
_num_digits = max(6, len(str(cfg.offline.steps + cfg.online.steps)))
step_identifier = f"{step:0{_num_digits}d}"
if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0:
if cfg.env is not None and cfg.eval_freq > 0 and step % cfg.eval_freq == 0:
logging.info(f"Eval policy at step {step}")
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
assert eval_env is not None
eval_info = eval_policy(
eval_env,
policy,
cfg.eval.n_episodes,
videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}",
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_identifier}",
max_episodes_rendered=4,
start_seed=cfg.seed,
)
@@ -376,28 +258,27 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
logger.log_video(eval_info["video_paths"][0], step, mode="eval")
logging.info("Resume training")
if cfg.training.save_checkpoint and (
step % cfg.training.save_freq == 0
or step == cfg.training.offline_steps + cfg.training.online_steps
if cfg.save_checkpoint and (
step % cfg.save_freq == 0 or step == cfg.offline.steps + cfg.online.steps
):
logging.info(f"Checkpoint policy after step {step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
# needed (choose 6 as a minimum for consistency without being overkill).
logger.save_checkpoint(
step,
step_identifier,
policy,
optimizer,
lr_scheduler,
identifier=step_identifier,
)
logging.info("Resume training")
# create dataloader for offline training
if cfg.training.get("drop_n_last_frames"):
if getattr(cfg.policy, "drop_n_last_frames", None):
shuffle = False
sampler = EpisodeAwareSampler(
offline_dataset.episode_data_index,
drop_n_last_frames=cfg.training.drop_n_last_frames,
drop_n_last_frames=cfg.policy.drop_n_last_frames,
shuffle=True,
)
else:
@@ -405,8 +286,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
sampler = None
dataloader = torch.utils.data.DataLoader(
offline_dataset,
num_workers=cfg.training.num_workers,
batch_size=cfg.training.batch_size,
num_workers=cfg.num_workers,
batch_size=cfg.batch_size,
shuffle=shuffle,
sampler=sampler,
pin_memory=device.type != "cpu",
@@ -416,7 +297,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
policy.train()
offline_step = 0
for _ in range(step, cfg.training.offline_steps):
for _ in range(step, cfg.offline.steps):
if offline_step == 0:
logging.info("Start offline training on a fixed dataset")
@@ -431,7 +312,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
policy,
batch,
optimizer,
cfg.training.grad_clip_norm,
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
@@ -439,7 +320,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
train_info["dataloading_s"] = dataloading_s
if step % cfg.training.log_freq == 0:
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
@@ -449,7 +330,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
step += 1
offline_step += 1 # noqa: SIM113
if cfg.training.online_steps == 0:
if cfg.online.steps == 0:
if eval_env:
eval_env.close()
logging.info("End of training")
@@ -458,8 +339,8 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
# Online training.
# Create an env dedicated to online episodes collection from policy rollout.
online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
resolve_delta_timestamps(cfg)
online_env = make_env(cfg.env, n_envs=cfg.online.rollout_batch_size)
delta_timestamps = resolve_delta_timestamps(cfg.policy, offline_dataset.meta)
online_buffer_path = logger.log_dir / "online_buffer"
if cfg.resume and not online_buffer_path.exists():
# If we are resuming a run, we default to the data shapes and buffer capacity from the saved online
@@ -473,31 +354,39 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
online_dataset = OnlineBuffer(
online_buffer_path,
data_spec={
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()},
**{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()},
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.input_features.items()
},
**{
key: {"shape": ft.shape, "dtype": np.dtype("float32")}
for key, ft in policy.config.output_features.items()
},
"next.reward": {"shape": (), "dtype": np.dtype("float32")},
"next.done": {"shape": (), "dtype": np.dtype("?")},
"task_index": {"shape": (), "dtype": np.dtype("int64")},
# FIXME: 'next.success' is expected by pusht env but not xarm
"next.success": {"shape": (), "dtype": np.dtype("?")},
},
buffer_capacity=cfg.training.online_buffer_capacity,
buffer_capacity=cfg.online.buffer_capacity,
fps=online_env.unwrapped.metadata["render_fps"],
delta_timestamps=cfg.training.delta_timestamps,
delta_timestamps=delta_timestamps,
)
# If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this
# makes it possible to do online rollouts in parallel with training updates).
online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy
online_rollout_policy = deepcopy(policy) if cfg.online.do_rollout_async else policy
# Create dataloader for online training.
concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset])
sampler_weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.training.online_sampling_ratio,
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler = torch.utils.data.WeightedRandomSampler(
sampler_weights,
@@ -506,20 +395,22 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
)
dataloader = torch.utils.data.DataLoader(
concat_dataset,
batch_size=cfg.training.batch_size,
num_workers=cfg.training.num_workers,
batch_size=cfg.batch_size,
num_workers=cfg.num_workers,
sampler=sampler,
pin_memory=device.type != "cpu",
drop_last=True,
)
dl_iter = cycle(dataloader)
# Lock and thread pool executor for asynchronous online rollouts. When asynchronous mode is disabled,
# these are still used but effectively do nothing.
lock = Lock()
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
# parallelization of rollouts is handled within the job.
executor = ThreadPoolExecutor(max_workers=1)
if cfg.online.do_rollout_async:
# Lock and thread pool executor for asynchronous online rollouts.
lock = Lock()
# Note: 1 worker because we only ever want to run one set of online rollouts at a time. Batch
# parallelization of rollouts is handled within the job.
executor = ThreadPoolExecutor(max_workers=1)
else:
lock = None
online_step = 0
online_rollout_s = 0 # time take to do online rollout
@@ -527,10 +418,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
# Time taken waiting for the online buffer to finish being updated. This is relevant when using the async
# online rollout option.
await_update_online_buffer_s = 0
rollout_start_seed = cfg.training.online_env_seed
rollout_start_seed = cfg.online.env_seed
while True:
if online_step == cfg.training.online_steps:
if online_step == cfg.online.steps:
break
if online_step == 0:
@@ -538,25 +429,33 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
def sample_trajectory_and_update_buffer():
nonlocal rollout_start_seed
with lock:
with lock if lock is not None else nullcontext():
online_rollout_policy.load_state_dict(policy.state_dict())
online_rollout_policy.eval()
start_rollout_time = time.perf_counter()
with torch.no_grad():
eval_info = eval_policy(
online_env,
online_rollout_policy,
n_episodes=cfg.training.online_rollout_n_episodes,
max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes),
n_episodes=cfg.online.rollout_n_episodes,
max_episodes_rendered=min(10, cfg.online.rollout_n_episodes),
videos_dir=logger.log_dir / "online_rollout_videos",
return_episode_data=True,
start_seed=(
rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000
),
start_seed=(rollout_start_seed := (rollout_start_seed + cfg.batch_size) % 1000000),
)
online_rollout_s = time.perf_counter() - start_rollout_time
with lock:
if len(offline_dataset.meta.tasks) > 1:
raise NotImplementedError("Add support for multi task.")
# Hack to add a task to the online_dataset (0 is the first task of the offline_dataset)
total_num_frames = eval_info["episodes"]["index"].shape[0]
eval_info["episodes"]["task_index"] = torch.tensor([0] * total_num_frames, dtype=torch.int64)
with lock if lock is not None else nullcontext():
start_update_buffer_time = time.perf_counter()
online_dataset.add_data(eval_info["episodes"])
@@ -566,12 +465,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
# Update the sampling weights.
sampler.weights = compute_sampler_weights(
offline_dataset,
offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0),
offline_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0),
online_dataset=online_dataset,
# +1 because online rollouts return an extra frame for the "final observation". Note: we don't have
# this final observation in the offline datasets, but we might add them in future.
online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.training.online_sampling_ratio,
online_drop_n_last_frames=getattr(cfg.policy, "drop_n_last_frames", 0) + 1,
online_sampling_ratio=cfg.online.sampling_ratio,
)
sampler.num_frames = len(concat_dataset)
@@ -579,36 +478,34 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
return online_rollout_s, update_online_buffer_s
future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if (
not cfg.training.do_online_rollout_async
or len(online_dataset) <= cfg.training.online_buffer_seed_size
):
online_rollout_s, update_online_buffer_s = future.result()
if lock is None:
online_rollout_s, update_online_buffer_s = sample_trajectory_and_update_buffer()
else:
future = executor.submit(sample_trajectory_and_update_buffer)
# If we aren't doing async rollouts, or if we haven't yet gotten enough examples in our buffer, wait
# here until the rollout and buffer update is done, before proceeding to the policy update steps.
if len(online_dataset) <= cfg.online.buffer_seed_size:
online_rollout_s, update_online_buffer_s = future.result()
if len(online_dataset) <= cfg.training.online_buffer_seed_size:
logging.info(
f"Seeding online buffer: {len(online_dataset)}/{cfg.training.online_buffer_seed_size}"
)
if len(online_dataset) <= cfg.online.buffer_seed_size:
logging.info(f"Seeding online buffer: {len(online_dataset)}/{cfg.online.buffer_seed_size}")
continue
policy.train()
for _ in range(cfg.training.online_steps_between_rollouts):
with lock:
for _ in range(cfg.online.steps_between_rollouts):
with lock if lock is not None else nullcontext():
start_time = time.perf_counter()
batch = next(dl_iter)
dataloading_s = time.perf_counter() - start_time
for key in batch:
batch[key] = batch[key].to(cfg.device, non_blocking=True)
batch[key] = batch[key].to(device, non_blocking=True)
train_info = update_policy(
policy,
batch,
optimizer,
cfg.training.grad_clip_norm,
cfg.optimizer.grad_clip_norm,
grad_scaler=grad_scaler,
lr_scheduler=lr_scheduler,
use_amp=cfg.use_amp,
@@ -619,10 +516,10 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
train_info["online_rollout_s"] = online_rollout_s
train_info["update_online_buffer_s"] = update_online_buffer_s
train_info["await_update_online_buffer_s"] = await_update_online_buffer_s
with lock:
with lock if lock is not None else nullcontext():
train_info["online_buffer_size"] = len(online_dataset)
if step % cfg.training.log_freq == 0:
if step % cfg.log_freq == 0:
log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True)
# Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed,
@@ -634,12 +531,12 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# If we're doing async rollouts, we should now wait until we've completed them before proceeding
# to do the next batch of rollouts.
if future.running():
if cfg.online.do_rollout_async and future.running():
start = time.perf_counter()
online_rollout_s, update_online_buffer_s = future.result()
await_update_online_buffer_s = time.perf_counter() - start
if online_step >= cfg.training.online_steps:
if online_step >= cfg.online.steps:
break
if eval_env:
@@ -647,23 +544,6 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("End of training")
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def train_cli(cfg: dict):
train(
cfg,
out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
)
def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"):
from hydra import compose, initialize
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=config_path)
cfg = compose(config_name=config_name)
train(cfg, out_dir=out_dir, job_name=job_name)
if __name__ == "__main__":
train_cli()
init_logging()
train()
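The Hydra entry points above are replaced by a plain `train()` call; CLI parsing now comes from draccus, which maps dataclass fields to flags (the same `@draccus.wrap()` pattern used by the visualization script below). A toy sketch of that pattern; the field names here are illustrative, not the real training config:

```python
from dataclasses import dataclass, field

import draccus

@dataclass
class OnlineConfig:
    steps: int = 0
    sampling_ratio: float = 0.5
    buffer_seed_size: int = 0

@dataclass
class TrainConfig:
    log_freq: int = 200
    online: OnlineConfig = field(default_factory=OnlineConfig)

@draccus.wrap()
def train(cfg: TrainConfig):
    # Nested fields become dotted flags, e.g. --online.steps=1000
    print(cfg)

if __name__ == "__main__":
    train()
```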

View File

@@ -18,142 +18,102 @@
This script will generate examples of transformed images as they are output by a LeRobot dataset.
Additionally, each individual transform can be visualized separately, as well as examples of combined transforms.
--- Usage Examples ---
Increase hue jitter
```
Example:
```bash
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.hue.min_max="[-0.25,0.25]"
--repo_id=lerobot/pusht \
--episodes='[0]' \
--image_transforms.enable=True
```
Increase brightness & brightness weight
```
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.brightness.weight=10.0 \
training.image_transforms.brightness.min_max="[1.0,2.0]"
```
Blur images and disable saturation & hue
```
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.sharpness.weight=10.0 \
training.image_transforms.sharpness.min_max="[0.0,1.0]" \
training.image_transforms.saturation.weight=0.0 \
training.image_transforms.hue.weight=0.0
```
Use all transforms with random order
```
python lerobot/scripts/visualize_image_transforms.py \
dataset_repo_id=lerobot/aloha_mobile_shrimp \
training.image_transforms.max_num_transforms=5 \
training.image_transforms.random_order=true
```
"""
import logging
from copy import deepcopy
from dataclasses import replace
from pathlib import Path
import hydra
import draccus
from torchvision.transforms import ToPILImage
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.datasets.transforms import (
ImageTransforms,
ImageTransformsConfig,
make_transform_from_config,
)
from lerobot.configs.default import DatasetConfig
OUTPUT_DIR = Path("outputs/image_transforms")
to_pil = ToPILImage()
def save_config_all_transforms(cfg, original_frame, output_dir, n_examples):
tf = get_image_transforms(
brightness_weight=cfg.brightness.weight,
brightness_min_max=cfg.brightness.min_max,
contrast_weight=cfg.contrast.weight,
contrast_min_max=cfg.contrast.min_max,
saturation_weight=cfg.saturation.weight,
saturation_min_max=cfg.saturation.min_max,
hue_weight=cfg.hue.weight,
hue_min_max=cfg.hue.min_max,
sharpness_weight=cfg.sharpness.weight,
sharpness_min_max=cfg.sharpness.min_max,
max_num_transforms=cfg.max_num_transforms,
random_order=cfg.random_order,
)
def save_all_transforms(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
output_dir_all = output_dir / "all"
output_dir_all.mkdir(parents=True, exist_ok=True)
tfs = ImageTransforms(cfg)
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
transformed_frame = tfs(original_frame)
to_pil(transformed_frame).save(output_dir_all / f"{i}.png", quality=100)
print("Combined transforms examples saved to:")
print(f" {output_dir_all}")
def save_config_single_transforms(cfg, original_frame, output_dir, n_examples):
transforms = [
"brightness",
"contrast",
"saturation",
"hue",
"sharpness",
]
def save_each_transform(cfg: ImageTransformsConfig, original_frame, output_dir, n_examples):
if not cfg.enable:
logging.warning(
"No single transforms will be saved, because `image_transforms.enable=False`. To enable, set `enable` to True in `ImageTransformsConfig` or in the command line with `--image_transforms.enable=True`."
)
return
print("Individual transforms examples saved to:")
for transform in transforms:
# Apply one transformation with a random value in the min_max range
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": cfg[f"{transform}"].min_max,
}
tf = get_image_transforms(**kwargs)
output_dir_single = output_dir / f"{transform}"
for tf_name, tf_cfg in cfg.tfs.items():
# Apply a few transformations with random values in the min_max range
output_dir_single = output_dir / tf_name
output_dir_single.mkdir(parents=True, exist_ok=True)
tf = make_transform_from_config(tf_cfg)
for i in range(1, n_examples + 1):
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / f"{i}.png", quality=100)
# Apply min transformation
min_value, max_value = cfg[f"{transform}"].min_max
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (min_value, min_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "min.png", quality=100)
# Apply min, max, average transformations
tf_cfg_kwgs_min = deepcopy(tf_cfg.kwargs)
tf_cfg_kwgs_max = deepcopy(tf_cfg.kwargs)
tf_cfg_kwgs_avg = deepcopy(tf_cfg.kwargs)
# Apply max transformation
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (max_value, max_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "max.png", quality=100)
for key, (min_, max_) in tf_cfg.kwargs.items():
avg = (min_ + max_) / 2
tf_cfg_kwgs_min[key] = [min_, min_]
tf_cfg_kwgs_max[key] = [max_, max_]
tf_cfg_kwgs_avg[key] = [avg, avg]
# Apply mean transformation
mean_value = (min_value + max_value) / 2
kwargs = {
f"{transform}_weight": cfg[f"{transform}"].weight,
f"{transform}_min_max": (mean_value, mean_value),
}
tf = get_image_transforms(**kwargs)
transformed_frame = tf(original_frame)
to_pil(transformed_frame).save(output_dir_single / "mean.png", quality=100)
tf_min = make_transform_from_config(replace(tf_cfg, kwargs=tf_cfg_kwgs_min))
tf_max = make_transform_from_config(replace(tf_cfg, kwargs=tf_cfg_kwgs_max))
tf_avg = make_transform_from_config(replace(tf_cfg, kwargs=tf_cfg_kwgs_avg))
tf_frame_min = tf_min(original_frame)
tf_frame_max = tf_max(original_frame)
tf_frame_avg = tf_avg(original_frame)
to_pil(tf_frame_min).save(output_dir_single / "min.png", quality=100)
to_pil(tf_frame_max).save(output_dir_single / "max.png", quality=100)
to_pil(tf_frame_avg).save(output_dir_single / "mean.png", quality=100)
print(f" {output_dir_single}")
def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
dataset = LeRobotDataset(cfg.dataset_repo_id)
@draccus.wrap()
def visualize_image_transforms(cfg: DatasetConfig, output_dir: Path = OUTPUT_DIR, n_examples: int = 5):
dataset = LeRobotDataset(
repo_id=cfg.repo_id,
episodes=cfg.episodes,
local_files_only=cfg.local_files_only,
video_backend=cfg.video_backend,
)
output_dir = output_dir / cfg.dataset_repo_id.split("/")[-1]
output_dir = output_dir / cfg.repo_id.split("/")[-1]
output_dir.mkdir(parents=True, exist_ok=True)
# Get 1st frame from 1st camera of 1st episode
@@ -162,14 +122,9 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
print("\nOriginal frame saved to:")
print(f" {output_dir / 'original_frame.png'}.")
save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
def visualize_transforms_cli(cfg):
visualize_transforms(cfg, output_dir=OUTPUT_DIR)
save_all_transforms(cfg.image_transforms, original_frame, output_dir, n_examples)
save_each_transform(cfg.image_transforms, original_frame, output_dir, n_examples)
if __name__ == "__main__":
visualize_transforms_cli()
visualize_image_transforms()