Simplify configs (#550)

Co-authored-by: Remi <remi.cadene@huggingface.co>
Co-authored-by: HUANG TZU-CHUN <137322177+tc-huang@users.noreply.github.com>
Simon Alibert
2025-01-31 13:57:37 +01:00
committed by GitHub
parent 1ee1acf8ad
commit 3c0a209f9f
119 changed files with 5761 additions and 5466 deletions


@@ -21,67 +21,69 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di
 for 10 episodes.
 ```
-python lerobot/scripts/eval.py -p lerobot/diffusion_pusht eval.n_episodes=10
+python lerobot/scripts/eval.py \
+    --policy.path=lerobot/diffusion_pusht \
+    --env.type=pusht \
+    --eval.batch_size=10 \
+    --eval.n_episodes=10 \
+    --use_amp=false \
+    --device=cuda
 ```
 
 OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes.
 
 ```
 python lerobot/scripts/eval.py \
-    -p outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
-    eval.n_episodes=10
+    --policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \
+    --env.type=pusht \
+    --eval.batch_size=10 \
+    --eval.n_episodes=10 \
+    --use_amp=false \
+    --device=cuda
 ```
 
-Note that in both examples, the repo/folder should contain at least `config.json`, `config.yaml` and
-`model.safetensors`.
+Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
 
-Note the formatting for providing the number of episodes. Generally, you may provide any number of arguments
-with `qualified.parameter.name=value`. In this case, the parameter eval.n_episodes appears as `n_episodes`
-nested under `eval` in the `config.yaml` found at
-https://huggingface.co/lerobot/diffusion_pusht/tree/main.
+You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
 """
 
-import argparse
 import json
 import logging
 import threading
 import time
 from contextlib import nullcontext
 from copy import deepcopy
-from datetime import datetime as dt
+from dataclasses import asdict
 from pathlib import Path
 from pprint import pformat
 from typing import Callable
 
 import einops
 import gymnasium as gym
 import numpy as np
 import torch
-from huggingface_hub import snapshot_download
-from huggingface_hub.errors import RepositoryNotFoundError
-from huggingface_hub.utils._validators import HFValidationError
 from torch import Tensor, nn
 from tqdm import trange
 
-from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import preprocess_observation
 from lerobot.common.logger import log_output_dir
 from lerobot.common.policies.factory import make_policy
-from lerobot.common.policies.policy_protocol import Policy
+from lerobot.common.policies.pretrained import PreTrainedPolicy
 from lerobot.common.policies.utils import get_device_from_parameters
 from lerobot.common.utils.io_utils import write_video
 from lerobot.common.utils.utils import (
     get_safe_torch_device,
-    init_hydra_config,
     init_logging,
     inside_slurm,
     set_global_seed,
 )
+from lerobot.configs import parser
+from lerobot.configs.eval import EvalPipelineConfig
 
 
 def rollout(
     env: gym.vector.VectorEnv,
-    policy: Policy,
+    policy: PreTrainedPolicy,
     seeds: list[int] | None = None,
     return_observations: bool = False,
     render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
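The new `--section.field=value` flags map one-to-one onto nested dataclass fields, replacing the Hydra YAML lookup described in the deleted docstring note. A rough sketch of the layout those flags imply follows; only fields visible in this diff are shown, and the stand-in classes and defaults are illustrative guesses, not the real definitions in lerobot/configs/eval.py:

```
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class EnvStub:  # hypothetical stand-in, selected via --env.type=pusht
    type: str = "pusht"


@dataclass
class PolicyStub:  # hypothetical stand-in, loaded via --policy.path=...
    path: str | None = None


@dataclass
class EvalStub:
    n_episodes: int = 50          # --eval.n_episodes=10
    batch_size: int = 50          # --eval.batch_size=10
    use_async_envs: bool = False  # --eval.use_async_envs=true


@dataclass
class EvalPipelineConfigSketch:
    env: EnvStub = field(default_factory=EnvStub)
    policy: PolicyStub = field(default_factory=PolicyStub)
    eval: EvalStub = field(default_factory=EvalStub)
    device: str = "cuda"   # --device=cuda
    use_amp: bool = False  # --use_amp=false
    seed: int | None = 1000
    output_dir: Path | None = None
```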
@@ -208,7 +210,7 @@ def rollout(
 def eval_policy(
     env: gym.vector.VectorEnv,
-    policy: torch.nn.Module,
+    policy: PreTrainedPolicy,
     n_episodes: int,
     max_episodes_rendered: int = 0,
     videos_dir: Path | None = None,
@@ -232,7 +234,9 @@ def eval_policy(
     if max_episodes_rendered > 0 and not videos_dir:
         raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.")
 
-    assert isinstance(policy, Policy)
+    if not isinstance(policy, PreTrainedPolicy):
+        raise ValueError(policy)
 
     start = time.time()
     policy.eval()
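Swapping the `assert` for an explicit raise is more than style: assertions are stripped when Python runs with `-O`, so the old type guard could silently vanish, while the new branch always executes and raises a catchable exception. A minimal self-contained illustration (the class names are stand-ins, not lerobot types):

```
class BasePolicy: ...          # stand-in for PreTrainedPolicy
class MyPolicy(BasePolicy): ...


def check(policy) -> None:
    # assert isinstance(policy, BasePolicy)  # stripped under `python -O`
    if not isinstance(policy, BasePolicy):   # always enforced
        raise ValueError(policy)


check(MyPolicy())  # passes silently
try:
    check(object())
except ValueError as err:
    print(f"rejected: {err!r}")
```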
@@ -442,66 +446,43 @@ def _compile_episode_data(
     return data_dict
 
 
-def main(
-    pretrained_policy_path: Path | None = None,
-    hydra_cfg_path: str | None = None,
-    out_dir: str | None = None,
-    config_overrides: list[str] | None = None,
-):
-    assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None)
-    if pretrained_policy_path is not None:
-        hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides)
-    else:
-        hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides)
-
-    if hydra_cfg.eval.batch_size > hydra_cfg.eval.n_episodes:
-        raise ValueError(
-            "The eval batch size is greater than the number of eval episodes "
-            f"({hydra_cfg.eval.batch_size} > {hydra_cfg.eval.n_episodes}). As a result, {hydra_cfg.eval.batch_size} "
-            f"eval environments will be instantiated, but only {hydra_cfg.eval.n_episodes} will be used. "
-            "This might significantly slow down evaluation. To fix this, you should update your command "
-            f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={hydra_cfg.eval.batch_size}`), "
-            f"or lower the batch size (e.g. `eval.batch_size={hydra_cfg.eval.n_episodes}`)."
-        )
-
-    if out_dir is None:
-        out_dir = f"outputs/eval/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{hydra_cfg.env.name}_{hydra_cfg.policy.name}"
+@parser.wrap()
+def eval(cfg: EvalPipelineConfig):
+    logging.info(pformat(asdict(cfg)))
 
     # Check device is available
-    device = get_safe_torch_device(hydra_cfg.device, log=True)
+    device = get_safe_torch_device(cfg.device, log=True)
 
     torch.backends.cudnn.benchmark = True
     torch.backends.cuda.matmul.allow_tf32 = True
-    set_global_seed(hydra_cfg.seed)
+    set_global_seed(cfg.seed)
 
-    log_output_dir(out_dir)
+    log_output_dir(cfg.output_dir)
 
     logging.info("Making environment.")
-    env = make_env(hydra_cfg)
+    env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
 
     logging.info("Making policy.")
-    if hydra_cfg_path is None:
-        policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path))
-    else:
-        # Note: We need the dataset stats to pass to the policy's normalization modules.
-        policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats)
-    assert isinstance(policy, nn.Module)
+    policy = make_policy(
+        cfg=cfg.policy,
+        device=device,
+        env_cfg=cfg.env,
+    )
     policy.eval()
 
-    with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext():
+    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext():
         info = eval_policy(
             env,
             policy,
-            hydra_cfg.eval.n_episodes,
+            cfg.eval.n_episodes,
             max_episodes_rendered=10,
-            videos_dir=Path(out_dir) / "videos",
-            start_seed=hydra_cfg.seed,
+            videos_dir=Path(cfg.output_dir) / "videos",
+            start_seed=cfg.seed,
         )
     print(info["aggregated"])
 
     # Save info
-    with open(Path(out_dir) / "eval_info.json", "w") as f:
+    with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
         json.dump(info, f, indent=2)
 
     env.close()
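`make_env` now receives the env config plus `n_envs` and `use_async_envs` explicitly instead of digging them out of the Hydra config. In gymnasium terms, that last flag chooses between stepping all environments in-process or in worker subprocesses. A minimal sketch of the distinction (CartPole is only a placeholder; the real factory builds LeRobot envs from `cfg.env`):

```
import gymnasium as gym


def make_vector_env(n_envs: int, use_async_envs: bool) -> gym.vector.VectorEnv:
    env_fns = [lambda: gym.make("CartPole-v1") for _ in range(n_envs)]
    if use_async_envs:
        # One worker process per env: pays off for heavy simulators.
        return gym.vector.AsyncVectorEnv(env_fns)
    # All envs step sequentially in the calling process: easier to debug.
    return gym.vector.SyncVectorEnv(env_fns)


env = make_vector_env(n_envs=10, use_async_envs=False)
obs, info = env.reset(seed=0)
print(obs.shape)  # (10, 4): one row per env
env.close()
```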
@@ -509,76 +490,6 @@ def main(
logging.info("End of eval")
def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None):
try:
pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision))
except (HFValidationError, RepositoryNotFoundError) as e:
if isinstance(e, HFValidationError):
error_message = (
"The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID."
)
else:
error_message = (
"The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub."
)
logging.warning(f"{error_message} Treating it as a local directory.")
pretrained_policy_path = Path(pretrained_policy_name_or_path)
if not pretrained_policy_path.is_dir() or not pretrained_policy_path.exists():
raise ValueError(
"The provided pretrained_policy_name_or_path is not a valid/existing Hugging Face Hub "
"repo ID, nor is it an existing local directory."
)
return pretrained_policy_path
if __name__ == "__main__":
init_logging()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-p",
"--pretrained-policy-name-or-path",
help=(
"Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
"saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch "
"(useful for debugging). This argument is mutually exclusive with `--config`."
),
)
group.add_argument(
"--config",
help=(
"Path to a yaml config you want to use for initializing a policy from scratch (useful for "
"debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)."
),
)
parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.")
parser.add_argument(
"--out-dir",
help=(
"Where to save the evaluation outputs. If not provided, outputs are saved in "
"outputs/eval/{timestamp}_{env_name}_{policy_name}"
),
)
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
if args.pretrained_policy_name_or_path is None:
main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides)
else:
pretrained_policy_path = get_pretrained_policy_path(
args.pretrained_policy_name_or_path, revision=args.revision
)
main(
pretrained_policy_path=pretrained_policy_path,
out_dir=args.out_dir,
config_overrides=args.overrides,
)
eval()
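The whole argparse block collapses into `@parser.wrap()`, which builds the `EvalPipelineConfig` from the command line and hands it to the decorated function. Below is a toy approximation of that decorator pattern, assuming draccus-style `--key=value` parsing; the real implementation in lerobot/configs/parser.py also handles `--policy.path` loading, plugin discovery, and proper type coercion, so every name here is a hypothetical stand-in:

```
import sys
from dataclasses import dataclass, field
from functools import wraps


@dataclass
class EvalStub:  # hypothetical section, mirroring the --eval.* flags
    n_episodes: int = 50
    batch_size: int = 50


@dataclass
class ConfigStub:  # hypothetical top-level config
    device: str = "cuda"
    seed: int = 1000
    eval: EvalStub = field(default_factory=EvalStub)


def _assign(cfg, dotted: str, raw: str) -> None:
    """Set cfg.<a>.<b>... = raw, coerced to the field's current type."""
    *path, leaf = dotted.split(".")
    for part in path:
        cfg = getattr(cfg, part)
    setattr(cfg, leaf, type(getattr(cfg, leaf))(raw))


def wrap(config_cls):
    """Build config_cls from `--key=value` argv and pass it to the function."""
    def decorator(fn):
        @wraps(fn)
        def inner():
            cfg = config_cls()
            for arg in sys.argv[1:]:
                key, _, value = arg.removeprefix("--").partition("=")
                _assign(cfg, key, value)
            return fn(cfg)
        return inner
    return decorator


@wrap(ConfigStub)
def main(cfg: ConfigStub):
    print(cfg)


if __name__ == "__main__":
    main()  # e.g. `python sketch.py --eval.n_episodes=10 --device=cpu`
```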