diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py
index 01a4cf76d..7ffdf8733 100644
--- a/examples/3_train_policy.py
+++ b/examples/3_train_policy.py
@@ -50,6 +50,6 @@ for offline_step in trange(cfg.offline_steps):
         print(train_info)
 
 # Save the policy, configuration, and normalization stats for later use.
-policy.save(output_directory / "model.pt")
+policy.save_pretrained(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
 torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 6c2cc4f70..41d576dab 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -5,6 +5,7 @@ from pathlib import Path
 from omegaconf import OmegaConf
 from termcolor import colored
+from lerobot.common.policies.abstract import AbstractPolicy
 
 
 def log_output_dir(out_dir):
     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@@ -67,11 +68,11 @@ class Logger:
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
-    def save_model(self, policy, identifier):
+    def save_model(self, policy: AbstractPolicy, identifier):
         if self._save_model:
             self._model_dir.mkdir(parents=True, exist_ok=True)
             fp = self._model_dir / f"{str(identifier)}.pt"
-            policy.save(fp)
+            policy.save_pretrained(fp)
             if self._wandb and not self._disable_wandb_artifact:
                 # note wandb artifact does not accept ":" in its name
                 artifact = self._wandb.Artifact(
diff --git a/lerobot/common/policies/abstract.py b/lerobot/common/policies/abstract.py
index 6dc72bef5..34b4ea33e 100644
--- a/lerobot/common/policies/abstract.py
+++ b/lerobot/common/policies/abstract.py
@@ -2,14 +2,25 @@ from collections import deque
 
 import torch
 from torch import Tensor, nn
+from huggingface_hub import PyTorchModelHubMixin
 
 
-class AbstractPolicy(nn.Module):
+class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
     """Base policy which all policies should be derived from.
 
     The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
     documentation for more information.
 
+    The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
+    # Save policy weights to local directory
+    >>> policy.save_pretrained("my-awesome-policy")
+
+    # Push policy weights to the Hub
+    >>> policy.push_to_hub("my-awesome-policy")
+
+    # Download and initialize policy from the Hub
+    >>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
+
     Note:
         When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
             1. set the required class attributes:
@@ -22,7 +33,7 @@ class AbstractPolicy(nn.Module):
 
     name: str | None = None  # same name should be used to instantiate the policy in factory.py
 
-    def __init__(self, n_action_steps: int | None):
+    def __init__(self, n_action_steps: int | None = None):
         """
         n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
             action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
diff --git a/lerobot/common/policies/act/policy.py b/lerobot/common/policies/act/policy.py
index ae4f73200..4789761d7 100644
--- a/lerobot/common/policies/act/policy.py
+++ b/lerobot/common/policies/act/policy.py
@@ -136,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
-    def load(self, fp):
-        d = torch.load(fp)
+    def load(self, fp, device=None):
+        d = torch.load(fp, map_location=device)
         self.load_state_dict(d)
 
     def compute_loss(self, batch):
diff --git a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
index 7719fddea..38714a01c 100644
--- a/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
+++ b/lerobot/common/policies/diffusion/diffusion_unet_image_policy.py
@@ -32,7 +32,7 @@
 assert len(unexpected_keys) == 0
 
 Then in that same runtime you can also save the weights with the new aligned state_dict:
 ```
-policy.save("weights.pt")
+policy.save_pretrained("my-policy")
 ```
 Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
diff --git a/lerobot/common/policies/diffusion/policy.py b/lerobot/common/policies/diffusion/policy.py
index 82f39b285..7b89e3a70 100644
--- a/lerobot/common/policies/diffusion/policy.py
+++ b/lerobot/common/policies/diffusion/policy.py
@@ -203,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
-    def load(self, fp):
-        d = torch.load(fp)
+    def load(self, fp, device=None):
+        d = torch.load(fp, map_location=device)
         missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
         if len(missing_keys) > 0:
             assert all(k.startswith("ema_diffusion.") for k in missing_keys)
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 934f09629..1540ea5a4 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -1,35 +1,53 @@
-def make_policy(cfg):
+""" Factory for policies
+"""
+
+from lerobot.common.policies.abstract import AbstractPolicy
+
+
+def make_policy(cfg: dict) -> AbstractPolicy:
+    """ Instantiate a policy from the configuration.
+    Currently supports TD-MPC, Diffusion, and ACT: select the policy with cfg.policy.name: tdmpc, diffusion, act.
+
+    Args:
+        cfg: The configuration (DictConfig)
+
+    """
+    policy_kwargs = {}
     if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
         raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
 
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
-        policy = TDMPCPolicy(cfg.policy, cfg.device)
+        policy_cls = TDMPCPolicy
+        policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 
-        policy = DiffusionPolicy(
-            cfg=cfg.policy,
-            cfg_device=cfg.device,
-            cfg_noise_scheduler=cfg.noise_scheduler,
-            cfg_rgb_model=cfg.rgb_model,
-            cfg_obs_encoder=cfg.obs_encoder,
-            cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
-            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+        policy_cls = DiffusionPolicy
+        policy_kwargs = {
+            "cfg": cfg.policy,
+            "cfg_device": cfg.device,
+            "cfg_noise_scheduler": cfg.noise_scheduler,
+            "cfg_rgb_model": cfg.rgb_model,
+            "cfg_obs_encoder": cfg.obs_encoder,
+            "cfg_optimizer": cfg.optimizer,
+            "cfg_ema": cfg.ema,
+            "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
-        )
+        }
     elif cfg.policy.name == "act":
         from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
 
-        policy = ActionChunkingTransformerPolicy(
-            cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
-        )
+        policy_cls = ActionChunkingTransformerPolicy
+        policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
     else:
         raise ValueError(cfg.policy.name)
 
     if cfg.policy.pretrained_model_path:
+        # policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
+        policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
+
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
             if "offline" in cfg.pretrained_model_path:
@@ -38,6 +56,5 @@ def make_policy(cfg):
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
-        policy.load(cfg.policy.pretrained_model_path)
 
     return policy
diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py
index 64dcc94dc..9172c536f 100644
--- a/lerobot/common/policies/tdmpc/policy.py
+++ b/lerobot/common/policies/tdmpc/policy.py
@@ -122,9 +122,9 @@ class TDMPCPolicy(AbstractPolicy):
         """Save state dict of TOLD model to filepath."""
         torch.save(self.state_dict(), fp)
 
-    def load(self, fp):
+    def load(self, fp, device=None):
         """Load a saved state dict from filepath into current agent."""
-        d = torch.load(fp)
+        d = torch.load(fp, map_location=device)
         self.model.load_state_dict(d["model"])
         self.model_target.load_state_dict(d["model_target"])
 
diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py
index 216769d6c..7383baafa 100644
--- a/lerobot/scripts/eval.py
+++ b/lerobot/scripts/eval.py
@@ -32,6 +32,7 @@
 import json
 import logging
 import threading
 import time
+from typing import Tuple, Union
 from datetime import datetime as dt
 from pathlib import Path
@@ -66,7 +67,19 @@ def eval_policy(
     video_dir: Path = None,
     fps: int = 15,
     return_first_video: bool = False,
-):
+) -> Union[dict, Tuple[dict, torch.Tensor]]:
+    """ Evaluate a policy on an environment by running rollouts and computing metrics.
+
+    Args:
+        env: The environment to evaluate.
+        policy: The policy to evaluate.
+        num_episodes: The number of episodes to evaluate.
+        max_steps: The maximum number of steps per episode.
+        save_video: Whether to save videos of the evaluation episodes.
+        video_dir: The directory to save the videos.
+        fps: The frames per second for the videos.
+        return_first_video: Whether to return the first video as a tensor.
+    """
     if policy is not None:
         policy.eval()
     start = time.time()
@@ -145,7 +158,7 @@ def eval_policy(
     for thread in threads:
         thread.join()
 
-    info = {
+    info = {  # TODO: change to dataclass
         "per_episode": [
             {
                 "episode_ix": i,
@@ -178,6 +191,13 @@ def eval_policy(
 
 
 def eval(cfg: dict, out_dir=None, stats_path=None):
+    """ Evaluate a policy.
+
+    Args:
+        cfg: The configuration (DictConfig).
+        out_dir: The directory to save the evaluation results (JSON file and videos)
+        stats_path: The path to the stats file.
+    """
     if out_dir is None:
         raise NotImplementedError()
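
For readers unfamiliar with `PyTorchModelHubMixin`, the following standalone sketch illustrates the round trip that the policies above now inherit. `TinyPolicy`, its constructor arguments, and the output path are illustrative placeholders, not LeRobot code. The relevant behaviour is that `save_pretrained` writes the weights to a directory, while `from_pretrained` re-instantiates the class, forwarding extra keyword arguments to `__init__` (which is why `make_policy` collects `policy_kwargs` before loading) and accepting `map_location` to choose the target device.

```
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class TinyPolicy(nn.Module, PyTorchModelHubMixin):
    """Toy stand-in for a LeRobot policy; the class and its arguments are illustrative only."""

    def __init__(self, obs_dim: int = 8, action_dim: int = 2):
        super().__init__()
        self.net = nn.Linear(obs_dim, action_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


policy = TinyPolicy()

# Save the weights to a local directory (the mixin handles serialization).
policy.save_pretrained("outputs/tiny-policy")

# Re-instantiate and load: keyword arguments are forwarded to __init__, and
# map_location controls which device the state dict is loaded onto, mirroring
# the from_pretrained(..., map_location=cfg.device, **policy_kwargs) call in
# the new make_policy.
restored = TinyPolicy.from_pretrained("outputs/tiny-policy", map_location="cpu", obs_dim=8, action_dim=2)

# Publishing to the Hub is the same one-liner shown in the AbstractPolicy docstring
# (requires authentication, e.g. `huggingface-cli login`):
# policy.push_to_hub("username/tiny-policy")
```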
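
Separately, the repeated `load(fp, device=None)` signature change simply threads a target device through to `torch.load`. A minimal illustration of what `map_location` does, using a hypothetical checkpoint path:

```
import torch

# With map_location set, tensors saved on another device (e.g. CUDA) are remapped
# onto the requested device at load time instead of failing on a CPU-only machine.
state_dict = torch.load("weights.pt", map_location="cpu")

# map_location=None keeps the devices recorded in the checkpoint, so the new
# device=None default preserves the previous behaviour of the policies' load().
state_dict_original_devices = torch.load("weights.pt", map_location=None)
```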