Compare commits: pre-commit...thom-propo

3 Commits

| SHA1 |
|---|
| a1e47202c0 |
| 24821fee24 |
| 4751642ace |
@@ -50,6 +50,6 @@ for offline_step in trange(cfg.offline_steps):
     print(train_info)
 
 # Save the policy, configuration, and normalization stats for later use.
-policy.save(output_directory / "model.pt")
+policy.save_pretrained(output_directory / "model.pt")
 OmegaConf.save(cfg, output_directory / "config.yaml")
 torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")
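For orientation, here is a minimal sketch of reading those three artifacts back, assuming the same output directory layout as above (the directory path is a placeholder; the model weights themselves are now restored through the policy class's `from_pretrained`, shown further down):

```python
from pathlib import Path

import torch
from omegaconf import OmegaConf

# Placeholder path: point this at the output_directory used by the training script.
output_directory = Path("outputs/train/example")

cfg = OmegaConf.load(output_directory / "config.yaml")  # training configuration
stats = torch.load(output_directory / "stats.pth")      # normalization statistics
print(OmegaConf.to_yaml(cfg))
print(type(stats))
```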
@@ -5,6 +5,7 @@ from pathlib import Path
 from omegaconf import OmegaConf
 from termcolor import colored
 
+from lerobot.common.policies.abstract import AbstractPolicy
 
 def log_output_dir(out_dir):
     logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@@ -67,11 +68,11 @@ class Logger:
             logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
             self._wandb = wandb
 
-    def save_model(self, policy, identifier):
+    def save_model(self, policy: AbstractPolicy, identifier):
         if self._save_model:
             self._model_dir.mkdir(parents=True, exist_ok=True)
             fp = self._model_dir / f"{str(identifier)}.pt"
-            policy.save(fp)
+            policy.save_pretrained(fp)
             if self._wandb and not self._disable_wandb_artifact:
                 # note wandb artifact does not accept ":" in its name
                 artifact = self._wandb.Artifact(
@@ -2,14 +2,25 @@ from collections import deque
 
 import torch
 from torch import Tensor, nn
+from huggingface_hub import PyTorchModelHubMixin
 
 
-class AbstractPolicy(nn.Module):
+class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
     """Base policy which all policies should be derived from.
 
     The forward method should generally not be overriden as it plays the role of handling multi-step policies. See its
     documentation for more information.
 
+    The policy is a PyTorchModelHubMixin, which means that it can be saved and loaded from the Hugging Face Hub and/or to a local directory.
+    # Save policy weights to local directory
+    >>> policy.save_pretrained("my-awesome-policy")
+
+    # Push policy weights to the Hub
+    >>> policy.push_to_hub("my-awesome-policy")
+
+    # Download and initialize policy from the Hub
+    >>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
+
     Note:
         When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
             1. set the required class attributes:
@@ -22,7 +33,7 @@ class AbstractPolicy(nn.Module):
 
     name: str | None = None # same name should be used to instantiate the policy in factory.py
 
-    def __init__(self, n_action_steps: int | None):
+    def __init__(self, n_action_steps: int | None = None):
         """
         n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
         action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
@@ -37,10 +48,10 @@ class AbstractPolicy(nn.Module):
         """One step of the policy's learning algorithm."""
         raise NotImplementedError("Abstract method")
 
-    def save(self, fp):
+    def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
         torch.save(self.state_dict(), fp)
 
-    def load(self, fp):
+    def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
         d = torch.load(fp)
         self.load_state_dict(d)
 
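A self-contained sketch of the round trip that the new docstring advertises; `TinyPolicy` and the directory name are invented here purely for illustration, while the real policies inherit the same mixin through `AbstractPolicy`:

```python
import torch
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class TinyPolicy(nn.Module, PyTorchModelHubMixin):
    """Toy stand-in for a concrete policy, used only to exercise the mixin."""

    def __init__(self, n_action_steps: int = 1):
        super().__init__()
        self.n_action_steps = n_action_steps
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(x)


policy = TinyPolicy(n_action_steps=4)
policy.save_pretrained("my-awesome-policy")  # writes weights (and config, when available) to a local directory
restored = TinyPolicy.from_pretrained("my-awesome-policy", n_action_steps=4)
assert torch.allclose(policy.head.weight, restored.head.weight)
# policy.push_to_hub("username/my-awesome-policy")  # would upload to the Hub; requires authentication
```

Note that `save_pretrained` treats its argument as a directory, so the `output_directory / "model.pt"` argument in the training-script hunk above ends up naming a folder rather than a single weights file.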
@@ -136,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
-    def load(self, fp):
-        d = torch.load(fp)
+    def load(self, fp, device=None):
+        d = torch.load(fp, map_location=device)
         self.load_state_dict(d)
 
     def compute_loss(self, batch):
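The `map_location` argument added to `load` is what lets a checkpoint written on a GPU machine be read on a CPU-only one; a small sketch, with a made-up file name:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "weights_example.pt")  # placeholder checkpoint

# torch.load remaps tensor storages onto the requested device; passing device=None keeps
# the default behaviour, which is what load(fp, device=None) preserves for existing callers.
state_dict = torch.load("weights_example.pt", map_location="cpu")
model.load_state_dict(state_dict)
```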
@@ -32,7 +32,7 @@ assert len(unexpected_keys) == 0
 Then in that same runtime you can also save the weights with the new aligned state_dict:
 
 ```
-policy.save("weights.pt")
+policy.save_pretrained("my-policy")
 ```
 
 Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.
@@ -203,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
     def save(self, fp):
         torch.save(self.state_dict(), fp)
 
-    def load(self, fp):
-        d = torch.load(fp)
+    def load(self, fp, device=None):
+        d = torch.load(fp, map_location=device)
         missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
         if len(missing_keys) > 0:
             assert all(k.startswith("ema_diffusion.") for k in missing_keys)
@@ -1,35 +1,53 @@
-def make_policy(cfg):
+""" Factory for policies
+"""
+
+from lerobot.common.policies.abstract import AbstractPolicy
+
+
+def make_policy(cfg: dict) -> AbstractPolicy:
+    """ Instantiate a policy from the configuration.
+    Currently supports TD-MPC, Diffusion, and ACT: select the policy with cfg.policy.name: tdmpc, diffusion, act.
+
+    Args:
+        cfg: The configuration (DictConfig)
+
+    """
+    policy_kwargs = {}
     if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
         raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
 
     if cfg.policy.name == "tdmpc":
         from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
 
-        policy = TDMPCPolicy(cfg.policy, cfg.device)
+        policy_cls = TDMPCPolicy
+        policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
     elif cfg.policy.name == "diffusion":
         from lerobot.common.policies.diffusion.policy import DiffusionPolicy
 
-        policy = DiffusionPolicy(
-            cfg=cfg.policy,
-            cfg_device=cfg.device,
-            cfg_noise_scheduler=cfg.noise_scheduler,
-            cfg_rgb_model=cfg.rgb_model,
-            cfg_obs_encoder=cfg.obs_encoder,
-            cfg_optimizer=cfg.optimizer,
-            cfg_ema=cfg.ema,
-            n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
+        policy_cls = DiffusionPolicy
+        policy_kwargs = {
+            "cfg": cfg.policy,
+            "cfg_device": cfg.device,
+            "cfg_noise_scheduler": cfg.noise_scheduler,
+            "cfg_rgb_model": cfg.rgb_model,
+            "cfg_obs_encoder": cfg.obs_encoder,
+            "cfg_optimizer": cfg.optimizer,
+            "cfg_ema": cfg.ema,
+            "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
             **cfg.policy,
-        )
+        }
     elif cfg.policy.name == "act":
         from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
 
-        policy = ActionChunkingTransformerPolicy(
-            cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
-        )
+        policy_cls = ActionChunkingTransformerPolicy
+        policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
     else:
         raise ValueError(cfg.policy.name)
 
     if cfg.policy.pretrained_model_path:
+        # policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
+        policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
+
         # TODO(rcadene): hack for old pretrained models from fowm
         if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
             if "offline" in cfg.pretrained_model_path:
@@ -38,6 +56,5 @@ def make_policy(cfg):
                 policy.step[0] = 100000
             else:
                 raise NotImplementedError()
-        policy.load(cfg.policy.pretrained_model_path)
 
     return policy
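The shape of the refactor (pick a class plus its constructor kwargs, then either build it directly or route through `from_pretrained` when a checkpoint path is configured) can be sketched with a toy registry; the class, names, and paths below are placeholders, not lerobot's real ones:

```python
from huggingface_hub import PyTorchModelHubMixin
from torch import nn


class TinyActPolicy(nn.Module, PyTorchModelHubMixin):
    def __init__(self, n_action_steps: int = 1):
        super().__init__()
        self.n_action_steps = n_action_steps
        self.net = nn.Linear(4, 4)


def make_policy_sketch(name: str, pretrained_model_path: str | None = None):
    # Mirrors the new factory: select the class and its kwargs first, then decide
    # between direct construction and from_pretrained(path, **policy_kwargs).
    registry = {"act": (TinyActPolicy, {"n_action_steps": 2})}
    policy_cls, policy_kwargs = registry[name]
    if pretrained_model_path:
        return policy_cls.from_pretrained(pretrained_model_path, **policy_kwargs)
    return policy_cls(**policy_kwargs)


policy = make_policy_sketch("act")
policy.save_pretrained("tiny-act-checkpoint")  # local directory, name assumed
reloaded = make_policy_sketch("act", pretrained_model_path="tiny-act-checkpoint")
```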
@@ -122,9 +122,9 @@ class TDMPCPolicy(AbstractPolicy):
         """Save state dict of TOLD model to filepath."""
         torch.save(self.state_dict(), fp)
 
-    def load(self, fp):
+    def load(self, fp, device=None):
         """Load a saved state dict from filepath into current agent."""
-        d = torch.load(fp)
+        d = torch.load(fp, map_location=device)
         self.model.load_state_dict(d["model"])
         self.model_target.load_state_dict(d["model_target"])
 
@@ -32,6 +32,7 @@ import json
 import logging
 import threading
 import time
+from typing import Tuple, Union
 from datetime import datetime as dt
 from pathlib import Path
 
@@ -66,7 +67,19 @@ def eval_policy(
     video_dir: Path = None,
     fps: int = 15,
     return_first_video: bool = False,
-):
+) -> Union[dict, Tuple[dict, torch.Tensor]]:
+    """ Evaluate a policy on an environment by running rollouts and computing metrics.
+
+    Args:
+        env: The environment to evaluate.
+        policy: The policy to evaluate.
+        num_episodes: The number of episodes to evaluate.
+        max_steps: The maximum number of steps per episode.
+        save_video: Whether to save videos of the evaluation episodes.
+        video_dir: The directory to save the videos.
+        fps: The frames per second for the videos.
+        return_first_video: Whether to return the first video as a tensor.
+    """
     if policy is not None:
         policy.eval()
     start = time.time()
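A tiny stand-in that mirrors the annotated contract (the real `eval_policy` runs rollouts in the given environment; this only shows how the `Union` return type behaves with `return_first_video`):

```python
from typing import Tuple, Union

import torch


def eval_policy_sketch(return_first_video: bool = False) -> Union[dict, Tuple[dict, torch.Tensor]]:
    # Stand-in metrics dict shaped like the "per_episode" structure built above.
    info = {"per_episode": [{"episode_ix": 0, "success": True}]}
    if return_first_video:
        first_video = torch.zeros(10, 3, 64, 64)  # placeholder (T, C, H, W) video tensor
        return info, first_video
    return info


metrics = eval_policy_sketch()
metrics, video = eval_policy_sketch(return_first_video=True)
```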
@@ -145,7 +158,7 @@ def eval_policy(
     for thread in threads:
         thread.join()
 
-    info = {
+    info = { # TODO: change to dataclass
         "per_episode": [
             {
                 "episode_ix": i,
@@ -178,6 +191,13 @@
 
 
 def eval(cfg: dict, out_dir=None, stats_path=None):
+    """ Evaluate a policy.
+
+    Args:
+        cfg: The configuration (DictConfig).
+        out_dir: The directory to save the evaluation results (JSON file and videos)
+        stats_path: The path to the stats file.
+    """
     if out_dir is None:
         raise NotImplementedError()
 
@@ -251,7 +271,8 @@ if __name__ == "__main__":
     elif args.hub_id is not None:
         folder = Path(snapshot_download(args.hub_id, revision=args.revision))
         cfg = init_hydra_config(
-            folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
+            folder / "config.yaml", [*args.overrides]
+            # folder / "config.yaml" # , [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
         )
         stats_path = folder / "stats.pth"
 
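The `hub_id` branch above no longer injects a `policy.pretrained_model_path` override when building the config; a hedged sketch of the download step it still performs (the repo id is hypothetical):

```python
from pathlib import Path

from huggingface_hub import snapshot_download

# Hypothetical repo id; snapshot_download returns the local cache folder that holds
# config.yaml and stats.pth, matching the branch above.
folder = Path(snapshot_download("username/my-awesome-policy", revision="main"))
print(folder / "config.yaml")
print(folder / "stats.pth")
```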