Compare commits


3 Commits

Author       SHA1        Message                                                Date
Thomas Wolf  a1e47202c0  update                                                 2024-04-03 07:46:17 +02:00
Thomas Wolf  24821fee24  update                                                 2024-04-02 22:49:16 +02:00
Thomas Wolf  4751642ace  adding docstring and from_pretrained/save_pretrained  2024-04-02 22:45:21 +02:00
9 changed files with 83 additions and 33 deletions

View File

@@ -50,6 +50,6 @@ for offline_step in trange(cfg.offline_steps):
     print(train_info)

 # Save the policy, configuration, and normalization stats for later use.
-policy.save(output_directory / "model.pt")
+policy.save_pretrained(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
 torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
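As context for this change (not part of the diff), here is a minimal, self-contained sketch of the save_pretrained/from_pretrained round trip the example script now relies on; `TinyPolicy` and the output path are placeholders. Note that `PyTorchModelHubMixin.save_pretrained` treats its argument as a directory, so the "model.pt" suffix above ends up as a folder name rather than a single file.

```python
# Sketch only: TinyPolicy and the paths are placeholders, not classes from this repo.
from pathlib import Path

from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class TinyPolicy(nn.Module, PyTorchModelHubMixin):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)


output_directory = Path("outputs/example")
policy = TinyPolicy()

# save_pretrained() creates the directory if needed and writes the weights inside it.
policy.save_pretrained(output_directory / "model.pt")

# The same directory can be reloaded later through the mixin.
reloaded = TinyPolicy.from_pretrained(output_directory / "model.pt")
```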

View File

@@ -5,6 +5,7 @@ from pathlib import Path
 from omegaconf import OmegaConf
 from termcolor import colored

+from lerobot.common.policies.abstract import AbstractPolicy

 def log_output_dir(out_dir):
     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@@ -67,11 +68,11 @@ class Logger:
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb self._wandb = wandb
def save_model(self, policy, identifier): def save_model(self, policy: AbstractPolicy, identifier):
if self._save_model: if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True) self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt" fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp) policy.save_pretrained(fp)
if self._wandb and not self._disable_wandb_artifact: if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name # note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact( artifact = self._wandb.Artifact(

View File

@@ -2,14 +2,25 @@ from collections import deque
 import torch
 from torch import Tensor, nn
+from huggingface_hub import PyTorchModelHubMixin


-class AbstractPolicy(nn.Module):
+class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
     """Base policy which all policies should be derived from.

     The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
     documentation for more information.

+    The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
+
+    # Save policy weights to local directory
+    >>> policy.save_pretrained("my-awesome-policy")
+
+    # Push policy weights to the Hub
+    >>> policy.push_to_hub("my-awesome-policy")
+
+    # Download and initialize policy from the Hub
+    >>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
+
     Note:
         When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
             1. set the required class attributes:
@@ -22,7 +33,7 @@ class AbstractPolicy(nn.Module):
     name: str | None = None  # same name should be used to instantiate the policy in factory.py

-    def __init__(self, n_action_steps: int | None):
+    def __init__(self, n_action_steps: int | None = None):
         """
         n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
             action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
@@ -37,10 +48,10 @@ class AbstractPolicy(nn.Module):
"""One step of the policy's learning algorithm.""" """One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method") raise NotImplementedError("Abstract method")
def save(self, fp): def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
torch.save(self.state_dict(), fp) torch.save(self.state_dict(), fp)
def load(self, fp): def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
d = torch.load(fp) d = torch.load(fp)
self.load_state_dict(d) self.load_state_dict(d)
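The `n_action_steps: int | None = None` default added in this hunk fits how the mixin loads models: `PyTorchModelHubMixin.from_pretrained` instantiates the class itself and forwards extra keyword arguments to `__init__` before loading the state dict, so constructor arguments either need defaults or must be passed at load time. A rough sketch with a hypothetical subclass (`MyPolicy` is not a class from this repo):

```python
# Sketch only: MyPolicy is a hypothetical concrete policy used for illustration.
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class MyPolicy(nn.Module, PyTorchModelHubMixin):
    def __init__(self, n_action_steps: int | None = None):
        super().__init__()
        self.n_action_steps = n_action_steps
        self.net = nn.Linear(8, 8)


policy = MyPolicy(n_action_steps=4)
policy.save_pretrained("my-awesome-policy")

# from_pretrained() builds the object with cls(**kwargs) and then loads the weights;
# map_location is consumed by the mixin itself, the rest goes to __init__.
policy = MyPolicy.from_pretrained("my-awesome-policy", n_action_steps=4, map_location="cpu")
```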

View File

@@ -136,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
     def save(self, fp):
         torch.save(self.state_dict(), fp)

-    def load(self, fp):
-        d = torch.load(fp)
+    def load(self, fp, device=None):
+        d = torch.load(fp, map_location=device)
         self.load_state_dict(d)

     def compute_loss(self, batch):
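The same `map_location` plumbing is repeated for the diffusion and TD-MPC policies below. For illustration only (the checkpoint path is made up), `map_location` remaps the saved tensors to the requested device at load time:

```python
import torch

# Hypothetical checkpoint path: a state dict saved on GPU can be materialized
# directly on CPU (or any other device) by remapping storages at load time.
state_dict = torch.load("outputs/train/model.pt", map_location="cpu")

# With map_location=None the tensors keep the devices recorded in the file,
# which fails on a machine that lacks the original GPU.
```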

View File

@@ -32,7 +32,7 @@ assert len(unexpected_keys) == 0
 Then in that same runtime you can also save the weights with the new aligned state_dict:

 ```
-policy.save("weights.pt")
+policy.save_pretrained("my-policy")
 ```

 Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.

View File

@@ -203,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
     def save(self, fp):
         torch.save(self.state_dict(), fp)

-    def load(self, fp):
-        d = torch.load(fp)
+    def load(self, fp, device=None):
+        d = torch.load(fp, map_location=device)
         missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
         if len(missing_keys) > 0:
             assert all(k.startswith("ema_diffusion.") for k in missing_keys)

View File

@@ -1,35 +1,53 @@
-def make_policy(cfg):
+""" Factory for policies
+"""
+
+from lerobot.common.policies.abstract import AbstractPolicy
+
+
+def make_policy(cfg: dict) -> AbstractPolicy:
+    """ Instantiate a policy from the configuration.
+
+    Currently supports TD-MPC, Diffusion, and ACT: select the policy with cfg.policy.name: tdmpc, diffusion, act.
+
+    Args:
+        cfg: The configuration (DictConfig)
+    """
+    policy_kwargs = {}
     if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
         raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")

     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy

-        policy = TDMPCPolicy(cfg.policy, cfg.device)
+        policy_cls = TDMPCPolicy
+        policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy

-        policy = DiffusionPolicy(
-            cfg=cfg.policy,
-            cfg_device=cfg.device,
-            cfg_noise_scheduler=cfg.noise_scheduler,
-            cfg_rgb_model=cfg.rgb_model,
-            cfg_obs_encoder=cfg.obs_encoder,
-            cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
-            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
-            **cfg.policy,
-        )
+        policy_cls = DiffusionPolicy
+        policy_kwargs = {
+            "cfg": cfg.policy,
+            "cfg_device": cfg.device,
+            "cfg_noise_scheduler": cfg.noise_scheduler,
+            "cfg_rgb_model": cfg.rgb_model,
+            "cfg_obs_encoder": cfg.obs_encoder,
+            "cfg_optimizer": cfg.optimizer,
+            "cfg_ema": cfg.ema,
+            "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
+            **cfg.policy,
+        }
     elif cfg.policy.name == "act":
         from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy

-        policy = ActionChunkingTransformerPolicy(
-            cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
-        )
+        policy_cls = ActionChunkingTransformerPolicy
+        policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
     else:
         raise ValueError(cfg.policy.name)

     if cfg.policy.pretrained_model_path:
+        # policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
+        policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
             if "offline" in cfg.pretrained_model_path:
@@ -38,6 +56,5 @@ def make_policy(cfg):
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
-        policy.load(cfg.policy.pretrained_model_path)

     return policy
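The refactor defers construction: each branch now only records `policy_cls` and `policy_kwargs`, and the pretrained branch hands both to `from_pretrained`. A rough sketch of that dispatch pattern with placeholder policy classes (the non-pretrained branch is not visible in this hunk, so the plain-construction fallback below is an assumption):

```python
# Sketch of the class/kwargs dispatch pattern; PolicyA/PolicyB are placeholders
# standing in for the TD-MPC / diffusion / ACT policies.
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class PolicyA(nn.Module, PyTorchModelHubMixin):
    def __init__(self, n_action_steps: int | None = None):
        super().__init__()
        self.net = nn.Linear(4, 4)
        self.n_action_steps = n_action_steps


class PolicyB(PolicyA):
    pass


def make_policy_sketch(name: str, pretrained_path: str | None = None, device: str = "cpu"):
    # Each branch only records *what* to build and with which kwargs.
    if name == "a":
        policy_cls, policy_kwargs = PolicyA, {"n_action_steps": 2}
    elif name == "b":
        policy_cls, policy_kwargs = PolicyB, {"n_action_steps": 4}
    else:
        raise ValueError(name)

    if pretrained_path:
        # from_pretrained consumes map_location itself and forwards the remaining
        # kwargs to the policy's __init__ before loading the state dict.
        return policy_cls.from_pretrained(pretrained_path, map_location=device, **policy_kwargs)
    # Assumed fallback (not visible in this hunk): plain construction from the same kwargs.
    return policy_cls(**policy_kwargs)
```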

View File

@@ -122,9 +122,9 @@ class TDMPCPolicy(AbstractPolicy):
"""Save state dict of TOLD model to filepath.""" """Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp) torch.save(self.state_dict(), fp)
def load(self, fp): def load(self, fp, device=None):
"""Load a saved state dict from filepath into current agent.""" """Load a saved state dict from filepath into current agent."""
d = torch.load(fp) d = torch.load(fp, map_location=device)
self.model.load_state_dict(d["model"]) self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"]) self.model_target.load_state_dict(d["model_target"])

View File

@@ -32,6 +32,7 @@ import json
 import logging
 import threading
 import time
+from typing import Tuple, Union

 from datetime import datetime as dt
 from pathlib import Path
@@ -66,7 +67,19 @@ def eval_policy(
     video_dir: Path = None,
     fps: int = 15,
     return_first_video: bool = False,
-):
+) -> Union[dict, Tuple[dict, torch.Tensor]]:
+    """ Evaluate a policy on an environment by running rollouts and computing metrics.
+
+    Args:
+        env: The environment to evaluate.
+        policy: The policy to evaluate.
+        num_episodes: The number of episodes to evaluate.
+        max_steps: The maximum number of steps per episode.
+        save_video: Whether to save videos of the evaluation episodes.
+        video_dir: The directory to save the videos.
+        fps: The frames per second for the videos.
+        return_first_video: Whether to return the first video as a tensor.
+    """
     if policy is not None:
         policy.eval()
     start = time.time()
@@ -145,7 +158,7 @@ def eval_policy(
     for thread in threads:
         thread.join()

-    info = {
+    info = {  # TODO: change to dataclass
         "per_episode": [
             {
                 "episode_ix": i,
@@ -178,6 +191,13 @@ def eval_policy(
 def eval(cfg: dict, out_dir=None, stats_path=None):
+    """ Evaluate a policy.
+
+    Args:
+        cfg: The configuration (DictConfig).
+        out_dir: The directory to save the evaluation results (JSON file and videos)
+        stats_path: The path to the stats file.
+    """
     if out_dir is None:
         raise NotImplementedError()
@@ -251,7 +271,8 @@ if __name__ == "__main__":
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = init_hydra_config(
-            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
+            folder / "config.yaml", [*args.overrides]
+            # folder / "config.yaml" # , [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
         )
         stats_path = folder / "stats.pth"
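For reference (not part of the diff), this block expects a run folder fetched from the Hub containing config.yaml, stats.pth, and the policy weights; a minimal sketch with a placeholder repo id:

```python
from pathlib import Path

from huggingface_hub import snapshot_download

# Hypothetical repo id: snapshot_download returns a local cache path with the
# files uploaded for the run (config.yaml, stats.pth, policy weights).
folder = Path(snapshot_download("username/my-awesome-run", revision="main"))
print(folder / "config.yaml", folder / "stats.pth")
```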