Add ability to eval hub model

This commit is contained in:
Alexander Soare
2024-03-22 10:26:55 +00:00
parent b633748987
commit 8720c568d0
4 changed files with 132 additions and 16 deletions

View File

@@ -1,6 +1,36 @@
"""Evaluate a policy on an environment by running rollouts and computing metrics.
The script may be run in one of two ways:
1. By providing the path to a config file with the --config argument.
2. By providing a HuggingFace Hub ID with the --hub-id argument. You may also provide a revision number with the
--revision argument.
In either case, it is possible to override config arguments by adding a list of config.key=value arguments.
Examples:
You have a specific config file to go with trained model weights, and want to run 10 episodes.
```
python lerobot/scripts/eval.py --config PATH/TO/FOLDER/config.yaml \
policy.pretrained_model_path=PATH/TO/FOLDER/weights.pth` eval_episodes=10
```
You have a HuggingFace Hub ID, you know which revision you want, and want to run 10 episodes (note that in this case,
you don't need to specify which weights to use):
```
python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
```
"""
import argparse
import logging
import os.path as osp
import threading
import time
from datetime import datetime as dt
from pathlib import Path
import einops
@@ -9,6 +39,7 @@ import imageio
import numpy as np
import torch
import tqdm
from huggingface_hub import snapshot_download
from tensordict.nn import TensorDictModule
from torchrl.envs import EnvBase
from torchrl.envs.batched_envs import BatchedEnvBase
@@ -65,8 +96,8 @@ def eval_policy(
callback=maybe_render_frame,
break_when_any_done=env.batch_size[0] == 1,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after this won't
# be included).
# Figure out where in each rollout sequence the first done condition was encountered (results after
# this won't be included).
# Note: this assumes that the shape of the done key is (batch_size, max_steps, 1).
# Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker.
rollout_steps = rollout["next", "done"].shape[1]
@@ -119,12 +150,7 @@ def eval_policy(
return info
@hydra.main(version_base=None, config_name="default", config_path="../configs")
def eval_cli(cfg: dict):
eval(cfg, out_dir=hydra.core.hydra_config.HydraConfig.get().runtime.output_dir)
def eval(cfg: dict, out_dir=None):
def eval(cfg: dict, out_dir=None, stats_path=None):
if out_dir is None:
raise NotImplementedError()
@@ -139,10 +165,10 @@ def eval(cfg: dict, out_dir=None):
log_output_dir(out_dir)
logging.info("make_offline_buffer")
offline_buffer = make_offline_buffer(cfg)
logging.info("Making transforms.")
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
logging.info("make_env")
logging.info("Making environment.")
env = make_env(cfg, transform=offline_buffer.transform)
if cfg.policy.pretrained_model_path:
@@ -170,5 +196,52 @@ def eval(cfg: dict, out_dir=None):
logging.info("End of eval")
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
path2 = path2.absolute()
try:
return path1.relative_to(path2)
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
if __name__ == "__main__":
eval_cli()
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument("--config", help="Path to a specific yaml config you want to use.")
group.add_argument("--hub-id", help="HuggingFace Hub ID for a pretrained model.")
parser.add_argument("--revision", help="Optionally provide the HuggingFace Hub revision ID.")
parser.add_argument(
"overrides",
nargs="*",
help="Any key=value arguments to override config values (use dots for.nested=overrides)",
)
args = parser.parse_args()
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
hydra.initialize(
config_path=str(
_relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
)
)
cfg = hydra.compose(Path(args.config).stem, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
cfg = hydra.compose("config", args.overrides)
cfg.policy.pretrained_model_path = folder / "model.pt"
stats_path = folder / "stats.pth"
eval(
cfg,
out_dir=f"eval_outputs/{dt.now().strftime('%Y-%m-%d/%H-%M-%S')}_{cfg.env.name}_{cfg.policy.name}",
)