From 70d7b99d0985884e5868aaf66a06cb5438e02aa5 Mon Sep 17 00:00:00 2001
From: Remi Cadene
Date: Fri, 22 Mar 2024 00:47:30 +0100
Subject: [PATCH] add pretrained

---
 examples/pretrained.py | 66 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)
 create mode 100644 examples/pretrained.py

diff --git a/examples/pretrained.py b/examples/pretrained.py
new file mode 100644
index 000000000..3d12897d5
--- /dev/null
+++ b/examples/pretrained.py
@@ -0,0 +1,66 @@
+"""Evaluate a pretrained diffusion policy on PushT, downloaded from the Hub."""
+import logging
+from pathlib import Path
+
+import torch
+from huggingface_hub import snapshot_download
+from omegaconf import OmegaConf
+from tensordict.nn import TensorDictModule
+
+from lerobot.common.datasets.factory import make_offline_buffer
+from lerobot.common.envs.factory import make_env
+from lerobot.common.logger import log_output_dir
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.utils import get_safe_torch_device, init_logging, set_seed
+from lerobot.scripts.eval import eval_policy
+
+# Fetch the pretrained checkpoint and its training config from the Hub.
+folder = Path(snapshot_download("lerobot/diffusion_policy_pusht_image", revision="v1.0"))
+cfg = OmegaConf.load(folder / "config.yaml")
+cfg.policy.pretrained_model_path = folder / "model.pt"
+# Keep the evaluation short: a single episode capped at 50 steps on CPU.
+# NOTE: eval_policy reads cfg.env.episode_length (see the call below), so the
+# cap must be set on cfg.env, not on a top-level cfg.episode_length.
+cfg.eval_episodes = 1
+cfg.env.episode_length = 50
+cfg.device = "cpu"
+
+out_dir = "test"
+
+init_logging()
+
+# Check device is available
+get_safe_torch_device(cfg.device, log=True)
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+set_seed(cfg.seed)
+
+log_output_dir(out_dir)
+
+logging.info("make_offline_buffer")
+offline_buffer = make_offline_buffer(cfg)
+
+logging.info("make_env")
+env = make_env(cfg, transform=offline_buffer.transform)
+
+if cfg.policy.pretrained_model_path:
+    policy = make_policy(cfg)
+    # Wrap the policy so it maps tensordict keys to the env's action key.
+    policy = TensorDictModule(
+        policy,
+        in_keys=["observation", "step_count"],
+        out_keys=["action"],
+    )
+else:
+    # when policy is None, rollout a random policy
+    policy = None
+
+metrics = eval_policy(
+    env,
+    policy=policy,
+    save_video=True,
+    video_dir=Path(out_dir) / "eval",
+    fps=cfg.env.fps,
+    max_steps=cfg.env.episode_length,
+    num_episodes=cfg.eval_episodes,
+)
+print(metrics)
+
+logging.info("End of eval")