Compare commits
3 Commits
main
...
thom-propo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a1e47202c0 | ||
|
|
24821fee24 | ||
|
|
4751642ace |
@@ -50,6 +50,6 @@ for offline_step in trange(cfg.offline_steps):
|
||||
print(train_info)
|
||||
|
||||
# Save the policy, configuration, and normalization stats for later use.
|
||||
policy.save(output_directory / "model.pt")
|
||||
policy.save_pretrained(output_directory / "model.pt")
|
||||
OmegaConf.save(cfg, output_directory / "config.yaml")
|
||||
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
|
||||
|
||||
@@ -5,6 +5,7 @@ from pathlib import Path
|
||||
from omegaconf import OmegaConf
|
||||
from termcolor import colored
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
def log_output_dir(out_dir):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
|
||||
@@ -67,11 +68,11 @@ class Logger:
|
||||
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
|
||||
self._wandb = wandb
|
||||
|
||||
def save_model(self, policy, identifier):
|
||||
def save_model(self, policy: AbstractPolicy, identifier):
|
||||
if self._save_model:
|
||||
self._model_dir.mkdir(parents=True, exist_ok=True)
|
||||
fp = self._model_dir / f"{str(identifier)}.pt"
|
||||
policy.save(fp)
|
||||
policy.save_pretrained(fp)
|
||||
if self._wandb and not self._disable_wandb_artifact:
|
||||
# note wandb artifact does not accept ":" in its name
|
||||
artifact = self._wandb.Artifact(
|
||||
|
||||
@@ -2,14 +2,25 @@ from collections import deque
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from huggingface_hub import PyTorchModelHubMixin
|
||||
|
||||
|
||||
class AbstractPolicy(nn.Module):
|
||||
class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
|
||||
"""Base policy which all policies should be derived from.
|
||||
|
||||
The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
|
||||
documentation for more information.
|
||||
|
||||
The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
|
||||
# Save policy weights to local directory
|
||||
>>> policy.save_pretrained("my-awesome-policy")
|
||||
|
||||
# Push policy weights to the Hub
|
||||
>>> policy.push_to_hub("my-awesome-policy")
|
||||
|
||||
# Download and initialize policy from the Hub
|
||||
>>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
|
||||
|
||||
Note:
|
||||
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
|
||||
1. set the required class attributes:
|
||||
@@ -22,7 +33,7 @@ class AbstractPolicy(nn.Module):
|
||||
|
||||
name: str | None = None # same name should be used to instantiate the policy in factory.py
|
||||
|
||||
def __init__(self, n_action_steps: int | None):
|
||||
def __init__(self, n_action_steps: int | None = None):
|
||||
"""
|
||||
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
|
||||
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
|
||||
@@ -37,10 +48,10 @@ class AbstractPolicy(nn.Module):
|
||||
"""One step of the policy's learning algorithm."""
|
||||
raise NotImplementedError("Abstract method")
|
||||
|
||||
def save(self, fp):
|
||||
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
|
||||
d = torch.load(fp)
|
||||
self.load_state_dict(d)
|
||||
|
||||
|
||||
@@ -136,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
def load(self, fp, device=None):
|
||||
d = torch.load(fp, map_location=device)
|
||||
self.load_state_dict(d)
|
||||
|
||||
def compute_loss(self, batch):
|
||||
|
||||
@@ -32,7 +32,7 @@ assert len(unexpected_keys) == 0
|
||||
Then in that same runtime you can also save the weights with the new aligned state_dict:
|
||||
|
||||
```
|
||||
policy.save("weights.pt")
|
||||
policy.save_pretrained("my-policy")
|
||||
```
|
||||
|
||||
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
|
||||
|
||||
@@ -203,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
|
||||
def save(self, fp):
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
d = torch.load(fp)
|
||||
def load(self, fp, device=None):
|
||||
d = torch.load(fp, map_location=device)
|
||||
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
|
||||
if len(missing_keys) > 0:
|
||||
assert all(k.startswith("ema_diffusion.") for k in missing_keys)
|
||||
|
||||
@@ -1,35 +1,53 @@
|
||||
def make_policy(cfg):
|
||||
""" Factory for policies
|
||||
"""
|
||||
|
||||
from lerobot.common.policies.abstract import AbstractPolicy
|
||||
|
||||
|
||||
def make_policy(cfg: dict) -> AbstractPolicy:
|
||||
""" Instantiate a policy from the configuration.
|
||||
Currently supports TD-MPC, Diffusion, and ACT: select the policy with cfg.policy.name: tdmpc, diffusion, act.
|
||||
|
||||
Args:
|
||||
cfg: The configuration (DictConfig)
|
||||
|
||||
"""
|
||||
policy_kwargs = {}
|
||||
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
|
||||
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
|
||||
|
||||
if cfg.policy.name == "tdmpc":
|
||||
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
|
||||
|
||||
policy = TDMPCPolicy(cfg.policy, cfg.device)
|
||||
policy_cls = TDMPCPolicy
|
||||
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
|
||||
elif cfg.policy.name == "diffusion":
|
||||
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
|
||||
|
||||
policy = DiffusionPolicy(
|
||||
cfg=cfg.policy,
|
||||
cfg_device=cfg.device,
|
||||
cfg_noise_scheduler=cfg.noise_scheduler,
|
||||
cfg_rgb_model=cfg.rgb_model,
|
||||
cfg_obs_encoder=cfg.obs_encoder,
|
||||
cfg_optimizer=cfg.optimizer,
|
||||
cfg_ema=cfg.ema,
|
||||
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
|
||||
policy_cls = DiffusionPolicy
|
||||
policy_kwargs = {
|
||||
"cfg": cfg.policy,
|
||||
"cfg_device": cfg.device,
|
||||
"cfg_noise_scheduler": cfg.noise_scheduler,
|
||||
"cfg_rgb_model": cfg.rgb_model,
|
||||
"cfg_obs_encoder": cfg.obs_encoder,
|
||||
"cfg_optimizer": cfg.optimizer,
|
||||
"cfg_ema": cfg.ema,
|
||||
"n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
|
||||
**cfg.policy,
|
||||
)
|
||||
}
|
||||
elif cfg.policy.name == "act":
|
||||
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
|
||||
|
||||
policy = ActionChunkingTransformerPolicy(
|
||||
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
|
||||
)
|
||||
policy_cls = ActionChunkingTransformerPolicy
|
||||
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
|
||||
else:
|
||||
raise ValueError(cfg.policy.name)
|
||||
|
||||
if cfg.policy.pretrained_model_path:
|
||||
# policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
|
||||
policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
|
||||
|
||||
# TODO(rcadene): hack for old pretrained models from fowm
|
||||
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
|
||||
if "offline" in cfg.pretrained_model_path:
|
||||
@@ -38,6 +56,5 @@ def make_policy(cfg):
|
||||
policy.step[0] = 100000
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
policy.load(cfg.policy.pretrained_model_path)
|
||||
|
||||
return policy
|
||||
|
||||
@@ -122,9 +122,9 @@ class TDMPCPolicy(AbstractPolicy):
|
||||
"""Save state dict of TOLD model to filepath."""
|
||||
torch.save(self.state_dict(), fp)
|
||||
|
||||
def load(self, fp):
|
||||
def load(self, fp, device=None):
|
||||
"""Load a saved state dict from filepath into current agent."""
|
||||
d = torch.load(fp)
|
||||
d = torch.load(fp, map_location=device)
|
||||
self.model.load_state_dict(d["model"])
|
||||
self.model_target.load_state_dict(d["model_target"])
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Tuple, Union
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
|
||||
@@ -66,7 +67,19 @@ def eval_policy(
|
||||
video_dir: Path = None,
|
||||
fps: int = 15,
|
||||
return_first_video: bool = False,
|
||||
):
|
||||
) -> Union[dict, Tuple[dict, torch.Tensor]]:
|
||||
""" Evaluate a policy on an environment by running rollouts and computing metrics.
|
||||
|
||||
Args:
|
||||
env: The environment to evaluate.
|
||||
policy: The policy to evaluate.
|
||||
num_episodes: The number of episodes to evaluate.
|
||||
max_steps: The maximum number of steps per episode.
|
||||
save_video: Whether to save videos of the evaluation episodes.
|
||||
video_dir: The directory to save the videos.
|
||||
fps: The frames per second for the videos.
|
||||
return_first_video: Whether to return the first video as a tensor.
|
||||
"""
|
||||
if policy is not None:
|
||||
policy.eval()
|
||||
start = time.time()
|
||||
@@ -145,7 +158,7 @@ def eval_policy(
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
info = {
|
||||
info = { # TODO: change to dataclass
|
||||
"per_episode": [
|
||||
{
|
||||
"episode_ix": i,
|
||||
@@ -178,6 +191,13 @@ def eval_policy(
|
||||
|
||||
|
||||
def eval(cfg: dict, out_dir=None, stats_path=None):
|
||||
""" Evaluate a policy.
|
||||
|
||||
Args:
|
||||
cfg: The configuration (DictConfig).
|
||||
out_dir: The directory to save the evaluation results (JSON file and videos)
|
||||
stats_path: The path to the stats file.
|
||||
"""
|
||||
if out_dir is None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -251,7 +271,8 @@ if __name__ == "__main__":
|
||||
elif args.hub_id is not None:
|
||||
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
|
||||
cfg = init_hydra_config(
|
||||
folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
folder / "config.yaml", [*args.overrides]
|
||||
# folder / "config.yaml" # , [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
|
||||
)
|
||||
stats_path = folder / "stats.pth"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user