Compare commits

...

3 Commits

Author SHA1 Message Date
Thomas Wolf
a1e47202c0 update 2024-04-03 07:46:17 +02:00
Thomas Wolf
24821fee24 update 2024-04-02 22:49:16 +02:00
Thomas Wolf
4751642ace adding docstring and from_pretrained/save_pretrained 2024-04-02 22:45:21 +02:00
9 changed files with 83 additions and 33 deletions

View File

@@ -50,6 +50,6 @@ for offline_step in trange(cfg.offline_steps):
print(train_info)
# Save the policy, configuration, and normalization stats for later use.
policy.save(output_directory / "model.pt")
policy.save_pretrained(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")

View File

@@ -5,6 +5,7 @@ from pathlib import Path
from omegaconf import OmegaConf
from termcolor import colored
from lerobot.common.policies.abstract import AbstractPolicy
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@@ -67,11 +68,11 @@ class Logger:
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def save_model(self, policy, identifier):
def save_model(self, policy: AbstractPolicy, identifier):
if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp)
policy.save_pretrained(fp)
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(

View File

@@ -2,14 +2,25 @@ from collections import deque
import torch
from torch import Tensor, nn
from huggingface_hub import PyTorchModelHubMixin
class AbstractPolicy(nn.Module):
class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
"""Base policy which all policies should be derived from.
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
documentation for more information.
The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
# Save policy weights to local directory
>>> policy.save_pretrained("my-awesome-policy")
# Push policy weights to the Hub
>>> policy.push_to_hub("my-awesome-policy")
# Download and initialize policy from the Hub
>>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
@@ -22,7 +33,7 @@ class AbstractPolicy(nn.Module):
name: str | None = None # same name should be used to instantiate the policy in factory.py
def __init__(self, n_action_steps: int | None):
def __init__(self, n_action_steps: int | None = None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
@@ -37,10 +48,10 @@ class AbstractPolicy(nn.Module):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
torch.save(self.state_dict(), fp)
def load(self, fp):
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
d = torch.load(fp)
self.load_state_dict(d)

View File

@@ -136,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
def load(self, fp, device=None):
d = torch.load(fp, map_location=device)
self.load_state_dict(d)
def compute_loss(self, batch):

View File

@@ -32,7 +32,7 @@ assert len(unexpected_keys) == 0
Then in that same runtime you can also save the weights with the new aligned state_dict:
```
policy.save("weights.pt")
policy.save_pretrained("my-policy")
```
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.

View File

@@ -203,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
def load(self, fp, device=None):
d = torch.load(fp, map_location=device)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)

View File

@@ -1,35 +1,53 @@
def make_policy(cfg):
""" Factory for policies
"""
from lerobot.common.policies.abstract import AbstractPolicy
def make_policy(cfg: dict) -> AbstractPolicy:
""" Instantiate a policy from the configuration.
Currently supports TD-MPC, Diffusion, and ACT: select the policy with cfg.policy.name: tdmpc, diffusion, act.
Args:
cfg: The configuration (DictConfig)
"""
policy_kwargs = {}
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPCPolicy(cfg.policy, cfg.device)
policy_cls = TDMPCPolicy
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
policy_cls = DiffusionPolicy
policy_kwargs = {
"cfg": cfg.policy,
"cfg_device": cfg.device,
"cfg_noise_scheduler": cfg.noise_scheduler,
"cfg_rgb_model": cfg.rgb_model,
"cfg_obs_encoder": cfg.obs_encoder,
"cfg_optimizer": cfg.optimizer,
"cfg_ema": cfg.ema,
"n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
}
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
)
policy_cls = ActionChunkingTransformerPolicy
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
else:
raise ValueError(cfg.policy.name)
if cfg.policy.pretrained_model_path:
# policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
@@ -38,6 +56,5 @@ def make_policy(cfg):
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path)
return policy

View File

@@ -122,9 +122,9 @@ class TDMPCPolicy(AbstractPolicy):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
def load(self, fp, device=None):
"""Load a saved state dict from filepath into current agent."""
d = torch.load(fp)
d = torch.load(fp, map_location=device)
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])

View File

@@ -32,6 +32,7 @@ import json
import logging
import threading
import time
from typing import Tuple, Union
from datetime import datetime as dt
from pathlib import Path
@@ -66,7 +67,19 @@ def eval_policy(
video_dir: Path = None,
fps: int = 15,
return_first_video: bool = False,
):
) -> Union[dict, Tuple[dict, torch.Tensor]]:
""" Evaluate a policy on an environment by running rollouts and computing metrics.
Args:
env: The environment to evaluate.
policy: The policy to evaluate.
num_episodes: The number of episodes to evaluate.
max_steps: The maximum number of steps per episode.
save_video: Whether to save videos of the evaluation episodes.
video_dir: The directory to save the videos.
fps: The frames per second for the videos.
return_first_video: Whether to return the first video as a tensor.
"""
if policy is not None:
policy.eval()
start = time.time()
@@ -145,7 +158,7 @@ def eval_policy(
for thread in threads:
thread.join()
info = {
info = { # TODO: change to dataclass
"per_episode": [
{
"episode_ix": i,
@@ -178,6 +191,13 @@ def eval_policy(
def eval(cfg: dict, out_dir=None, stats_path=None):
""" Evaluate a policy.
Args:
cfg: The configuration (DictConfig).
out_dir: The directory to save the evaluation results (JSON file and videos)
stats_path: The path to the stats file.
"""
if out_dir is None:
raise NotImplementedError()
@@ -251,7 +271,8 @@ if __name__ == "__main__":
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = init_hydra_config(
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
folder / "config.yaml", [*args.overrides]
# folder / "config.yaml" # , [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
)
stats_path = folder / "stats.pth"