finish examples 2 and 3
This commit is contained in:
@@ -30,14 +30,12 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os.path as osp
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
import einops
|
||||
import hydra
|
||||
import imageio
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -52,7 +50,7 @@ from lerobot.common.envs.factory import make_env
|
||||
from lerobot.common.logger import log_output_dir
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
|
||||
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
@@ -195,7 +193,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
log_output_dir(out_dir)
|
||||
|
||||
logging.info("Making transforms.")
|
||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
|
||||
# TODO(alexander-soare): Completely decouple datasets from evaluation.
|
||||
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None)
|
||||
|
||||
logging.info("Making environment.")
|
||||
env = make_env(cfg, transform=offline_buffer.transform)
|
||||
@@ -229,19 +228,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
def _relative_path_between(path1: Path, path2: Path) -> Path:
|
||||
"""Returns path1 relative to path2."""
|
||||
path1 = path1.absolute()
|
||||
path2 = path2.absolute()
|
||||
try:
|
||||
return path1.relative_to(path2)
|
||||
except ValueError: # most likely because path1 is not a subpath of path2
|
||||
common_parts = Path(osp.commonpath([path1, path2])).parts
|
||||
return Path(
|
||||
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
|
||||
@@ -259,19 +245,14 @@ if __name__ == "__main__":
|
||||
|
||||
if args.config is not None:
|
||||
# Note: For the config_path, Hydra wants a path relative to this script file.
|
||||
hydra.initialize(
|
||||
config_path=str(
|
||||
_relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
|
||||
)
|
||||
)
|
||||
cfg = hydra.compose(Path(args.config).stem, args.overrides)
|
||||
cfg = init_hydra_config(args.config, args.overrides)
|
||||
# TODO(alexander-soare): Save and load stats in trained model directory.
|
||||
stats_path = None
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision="v1.0"))
|
||||
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
|
||||
cfg = hydra.compose("config", args.overrides)
|
||||
cfg.policy.pretrained_model_path = folder / "model.pt"
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
stats_path = folder / "stats.pth"
|
||||
|
||||
eval(
|
||||
|
||||
Reference in New Issue
Block a user