Compare commits


11 Commits

Author | SHA1 | Message | Date
Thomas Wolf | a1e47202c0 | update | 2024-04-03 07:46:17 +02:00
Thomas Wolf | 24821fee24 | update | 2024-04-02 22:49:16 +02:00
Thomas Wolf | 4751642ace | adding docstring and from_pretrained/save_pretrained | 2024-04-02 22:45:21 +02:00
Alexander Soare | 11cbf1bea1 | Merge pull request #53 from alexander-soare/finish_examples (Add examples 2 and 3) | 2024-04-01 11:52:41 +01:00
Alexander Soare | f1148b8c2d | Merge remote-tracking branch 'upstream/main' into finish_examples | 2024-04-01 11:31:31 +01:00
Alexander Soare | b7c9c33072 | revision | 2024-03-27 18:33:48 +00:00
Alexander Soare | 120f0aef5c | Merge remote-tracking branch 'upstream/main' into finish_examples | 2024-03-27 17:52:36 +00:00
Alexander Soare | 6cd671040f | fix revision | 2024-03-27 13:22:14 +00:00
Alexander Soare | 011f2d27fe | fix tests | 2024-03-26 16:40:54 +00:00
Alexander Soare | be4441c7ff | update README | 2024-03-26 16:28:16 +00:00
Alexander Soare | 1ed0110900 | finish examples 2 and 3 | 2024-03-26 16:13:40 +00:00
22 changed files with 312 additions and 170 deletions


@@ -146,11 +146,7 @@ hydra.run.dir=outputs/visualize_dataset/example
### Evaluate a pretrained policy
You can import our environment class, download pretrained policies from the HuggingFace hub, and use our rollout utilities with rendering:
```python
""" Copy pasted from `examples/2_evaluate_pretrained_policy.py`
# TODO
```
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from the HuggingFace hub, set up the corresponding environment, and run an evaluation.
Or you can achieve the same result by executing our script from the command line:
```bash
@@ -160,7 +156,7 @@ eval_episodes=10 \
hydra.run.dir=outputs/eval/example_hub
```
After launching training of your own policy, you can also re-evaluate the checkpoints with:
After training your own policy, you can also re-evaluate the checkpoints with:
```bash
python lerobot/scripts/eval.py \
--config PATH/TO/FOLDER/config.yaml \
@@ -173,19 +169,9 @@ See `python lerobot/scripts/eval.py --help` for more instructions.
### Train your own policy
You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub):
```python
""" Copy pasted from `examples/3_train_policy.py`
# TODO
```
You can import our dataset, environment, and policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from the HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output!
Or you can achieve the same result by executing our script from the command line:
```bash
python lerobot/scripts/train.py \
hydra.run.dir=outputs/train/example
```
You can easily train any policy on any environment:
In general, you can use our training script to easily train any policy on any environment:
```bash
python lerobot/scripts/train.py \
env=aloha \


@@ -1 +1,39 @@
# TODO
"""
This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local
training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first.
"""
from pathlib import Path
from huggingface_hub import snapshot_download
from lerobot.common.utils import init_hydra_config
from lerobot.scripts.eval import eval
# Get a pretrained policy from the hub.
hub_id = "lerobot/diffusion_policy_pusht_image"
folder = Path(snapshot_download(hub_id))
# OR uncomment the following to evaluate a policy from the local outputs/train folder.
# folder = Path("outputs/train/example_pusht_diffusion")
config_path = folder / "config.yaml"
weights_path = folder / "model.pt"
stats_path = folder / "stats.pth" # normalization stats
# Override some config parameters to do with evaluation.
overrides = [
f"policy.pretrained_model_path={weights_path}",
"eval_episodes=10",
"rollout_batch_size=10",
"device=cuda",
]
# Create a Hydra config.
cfg = init_hydra_config(config_path, overrides)
# Evaluate the policy and save the outputs including metrics and videos.
eval(
cfg,
out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}",
stats_path=stats_path,
)


@@ -1 +1,55 @@
# TODO
"""This scripts demonstrates how to train Diffusion Policy on the PushT environment.
Once you have trained a model with this script, you can try to evaluate it on
examples/2_evaluate_pretrained_policy.py
"""
import os
from pathlib import Path
import torch
from omegaconf import OmegaConf
from tqdm import trange
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
from lerobot.common.utils import init_hydra_config
output_directory = Path("outputs/train/example_pusht_diffusion")
os.makedirs(output_directory, exist_ok=True)
overrides = [
"env=pusht",
"policy=diffusion",
# Adjust as you prefer. 5000 steps are needed to get something worth evaluating.
"offline_steps=5000",
"log_freq=250",
"device=cuda",
]
cfg = init_hydra_config("lerobot/configs/default.yaml", overrides)
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
policy.train()
offline_buffer = make_offline_buffer(cfg)
for offline_step in trange(cfg.offline_steps):
train_info = policy.update(offline_buffer, offline_step)
if offline_step % cfg.log_freq == 0:
print(train_info)
# Save the policy, configuration, and normalization stats for later use.
policy.save_pretrained(output_directory / "model.pt")
OmegaConf.save(cfg, output_directory / "config.yaml")
torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth")


@@ -20,8 +20,6 @@ def make_env(cfg, transform=None):
from lerobot.common.envs.simxarm.env import SimxarmEnv
kwargs["task"] = cfg.env.task
kwargs["visualization_width"] = cfg.env.visualization_width
kwargs["visualization_height"] = cfg.env.visualization_height
clsfunc = SimxarmEnv
elif cfg.env.name == "pusht":
from lerobot.common.envs.pusht.env import PushtEnv


@@ -38,14 +38,7 @@ class SimxarmEnv(AbstractEnv):
device="cpu",
num_prev_obs=0,
num_prev_action=0,
visualization_width=None,
visualization_height=None,
):
self.from_pixels = from_pixels
self.image_size = image_size
self.visualization_width = visualization_width
self.visualization_height = visualization_height
super().__init__(
task=task,
frame_skip=frame_skip,
@@ -69,18 +62,7 @@ class SimxarmEnv(AbstractEnv):
if self.task not in TASKS:
raise ValueError(f"Unknown task {self.task}. Must be one of {list(TASKS.keys())}")
kwargs = (
{
"width": self.image_size,
"height": self.image_size,
"visualization_width": self.visualization_width,
"visualization_height": self.visualization_height,
}
if self.from_pixels
else {}
)
self._env = TASKS[self.task]["env"](**kwargs)
self._env = TASKS[self.task]["env"]()
num_actions = len(TASKS[self.task]["action_space"])
self._action_space = gymnasium.spaces.Box(low=-1.0, high=1.0, shape=(num_actions,))
@@ -88,12 +70,12 @@ class SimxarmEnv(AbstractEnv):
if "w" not in TASKS[self.task]["action_space"]:
self._action_padding[-1] = 1.0
def render(self, mode="rgb_array"):
return self._env.render(mode)
def render(self, mode="rgb_array", width=384, height=384):
return self._env.render(mode, width=width, height=height)
def _format_raw_obs(self, raw_obs):
if self.from_pixels:
image = self.render(mode="rgb_array")
image = self.render(mode="rgb_array", width=self.image_size, height=self.image_size)
image = image.transpose(2, 0, 1) # (H, W, C) -> (C, H, W)
image = torch.tensor(image.copy(), dtype=torch.uint8)


@@ -1,9 +1,7 @@
# from copy import deepcopy
import os
import mujoco
import numpy as np
from gymnasium.envs.mujoco.mujoco_rendering import MujocoRenderer
from gymnasium_robotics.envs import robot_env
from lerobot.common.envs.simxarm.simxarm.tasks import mocap
@@ -17,29 +15,20 @@ class Base(robot_env.MujocoRobotEnv):
gripper_rotation (list): initial rotation of the gripper (given as a quaternion)
"""
def __init__(self, xml_name, gripper_rotation=None, **kwargs):
def __init__(self, xml_name, gripper_rotation=None):
if gripper_rotation is None:
gripper_rotation = [0, 1, 0, 0]
self.gripper_rotation = np.array(gripper_rotation, dtype=np.float32)
self.center_of_table = np.array([1.655, 0.3, 0.63625])
self.max_z = 1.2
self.min_z = 0.2
visualization_width = kwargs.pop("visualization_width") if "visualization_width" in kwargs else None
visualization_height = (
kwargs.pop("visualization_height") if "visualization_height" in kwargs else None
)
super().__init__(
model_path=os.path.join(os.path.dirname(__file__), "assets", xml_name + ".xml"),
n_substeps=20,
n_actions=4,
initial_qpos={},
**kwargs,
)
if visualization_width is not None and visualization_height is not None:
self._set_custom_size_renderer(width=visualization_width, height=visualization_height)
@property
def dt(self):
return self.n_substeps * self.model.opt.timestep
@@ -144,26 +133,12 @@ class Base(robot_env.MujocoRobotEnv):
info = {"is_success": self.is_success(), "success": self.is_success()}
return obs, reward, done, info
def render(self, mode="rgb_array"):
def render(self, mode="rgb_array", width=384, height=384):
self._render_callback()
if mode == "visualization":
return self._custom_size_render()
return self.mujoco_renderer.render(mode, camera_name="camera0")
def _set_custom_size_renderer(self, width, height):
from copy import deepcopy
# HACK
custom_render_model = deepcopy(self.model)
custom_render_model.vis.global_.offwidth = width
custom_render_model.vis.global_.offheight = height
self.custom_size_renderer = MujocoRenderer(custom_render_model, self.data)
del custom_render_model
def _custom_size_render(self):
return self.custom_size_renderer.render("rgb_array", camera_name="camera0")
self.model.vis.global_.offwidth = width
self.model.vis.global_.offheight = height
return self.mujoco_renderer.render(mode)
def close(self):
if self.mujoco_renderer is not None:


@@ -4,9 +4,9 @@ from lerobot.common.envs.simxarm.simxarm import Base
class Lift(Base):
def __init__(self, **kwargs):
def __init__(self):
self._z_threshold = 0.15
super().__init__("lift", **kwargs)
super().__init__("lift")
@property
def z_target(self):


@@ -5,6 +5,7 @@ from pathlib import Path
from omegaconf import OmegaConf
from termcolor import colored
from lerobot.common.policies.abstract import AbstractPolicy
def log_output_dir(out_dir):
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {out_dir}")
@@ -67,11 +68,11 @@ class Logger:
logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
self._wandb = wandb
def save_model(self, policy, identifier):
def save_model(self, policy: AbstractPolicy, identifier):
if self._save_model:
self._model_dir.mkdir(parents=True, exist_ok=True)
fp = self._model_dir / f"{str(identifier)}.pt"
policy.save(fp)
policy.save_pretrained(fp)
if self._wandb and not self._disable_wandb_artifact:
# note wandb artifact does not accept ":" in its name
artifact = self._wandb.Artifact(


@@ -2,14 +2,25 @@ from collections import deque
import torch
from torch import Tensor, nn
from huggingface_hub import PyTorchModelHubMixin
class AbstractPolicy(nn.Module):
class AbstractPolicy(nn.Module, PyTorchModelHubMixin):
"""Base policy which all policies should be derived from.
The forward method should generally not be overridden as it plays the role of handling multi-step policies. See its
documentation for more information.
The policy is a PyTorchModelHubMixin, which means that it can be saved to and loaded from the Hugging Face Hub and/or a local directory.
# Save policy weights to local directory
>>> policy.save_pretrained("my-awesome-policy")
# Push policy weights to the Hub
>>> policy.push_to_hub("my-awesome-policy")
# Download and initialize policy from the Hub
>>> policy = MyPolicy.from_pretrained("username/my-awesome-policy")
Note:
When implementing a concrete class (e.g. `AlohaDataset`, `PushtEnv`, `DiffusionPolicy`), you need to:
1. set the required class attributes:
@@ -22,7 +33,7 @@ class AbstractPolicy(nn.Module):
name: str | None = None # same name should be used to instantiate the policy in factory.py
def __init__(self, n_action_steps: int | None):
def __init__(self, n_action_steps: int | None = None):
"""
n_action_steps: Sets the cache size for storing action trajectories. If None, it is assumed that a single
action is returned by `select_actions` and that doesn't have a horizon dimension. The `forward` method then
@@ -37,10 +48,10 @@ class AbstractPolicy(nn.Module):
"""One step of the policy's learning algorithm."""
raise NotImplementedError("Abstract method")
def save(self, fp):
def save(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
torch.save(self.state_dict(), fp)
def load(self, fp):
def load(self, fp): # TODO: remove this method since we are using PyTorchModelHubMixin
d = torch.load(fp)
self.load_state_dict(d)
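
Not part of the diff: the docstring above describes the new `PyTorchModelHubMixin` workflow in the abstract. Below is a minimal sketch of that save/load round trip, relying only on the public `huggingface_hub` mixin API; `ToyPolicy` and the local folder name are made up for illustration:
```python
import torch
from torch import nn
from huggingface_hub import PyTorchModelHubMixin


class ToyPolicy(nn.Module, PyTorchModelHubMixin):
    """Toy stand-in for a concrete policy class (hypothetical, not part of lerobot)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 2)

    def forward(self, x):
        return self.net(x)


policy = ToyPolicy()
policy.save_pretrained("toy-policy")                # write the weights to a local folder
restored = ToyPolicy.from_pretrained("toy-policy")  # reload them; also accepts a Hub repo id
assert torch.allclose(policy.net.weight, restored.net.weight)
```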


@@ -136,8 +136,8 @@ class ActionChunkingTransformerPolicy(AbstractPolicy):
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
def load(self, fp, device=None):
d = torch.load(fp, map_location=device)
self.load_state_dict(d)
def compute_loss(self, batch):


@@ -32,7 +32,7 @@ assert len(unexpected_keys) == 0
Then in that same runtime you can also save the weights with the new aligned state_dict:
```
policy.save("weights.pt")
policy.save_pretrained("my-policy")
```
Now you can remove the breakpoint and extra code and load in the weights just like with any other lerobot checkpoint.


@@ -203,8 +203,8 @@ class DiffusionPolicy(AbstractPolicy):
def save(self, fp):
torch.save(self.state_dict(), fp)
def load(self, fp):
d = torch.load(fp)
def load(self, fp, device=None):
d = torch.load(fp, map_location=device)
missing_keys, unexpected_keys = self.load_state_dict(d, strict=False)
if len(missing_keys) > 0:
assert all(k.startswith("ema_diffusion.") for k in missing_keys)


@@ -1,35 +1,53 @@
def make_policy(cfg):
""" Factory for policies
"""
from lerobot.common.policies.abstract import AbstractPolicy
def make_policy(cfg: dict) -> AbstractPolicy:
""" Instantiate a policy from the configuration.
Currently supports TD-MPC, Diffusion, and ACT; select one with cfg.policy.name (tdmpc, diffusion, or act).
Args:
cfg: The configuration (DictConfig)
"""
policy_kwargs = {}
if cfg.policy.name != "diffusion" and cfg.rollout_batch_size > 1:
raise NotImplementedError("Only diffusion policy supports rollout_batch_size > 1 for the time being.")
if cfg.policy.name == "tdmpc":
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
policy = TDMPCPolicy(cfg.policy, cfg.device)
policy_cls = TDMPCPolicy
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device}
elif cfg.policy.name == "diffusion":
from lerobot.common.policies.diffusion.policy import DiffusionPolicy
policy = DiffusionPolicy(
cfg=cfg.policy,
cfg_device=cfg.device,
cfg_noise_scheduler=cfg.noise_scheduler,
cfg_rgb_model=cfg.rgb_model,
cfg_obs_encoder=cfg.obs_encoder,
cfg_optimizer=cfg.optimizer,
cfg_ema=cfg.ema,
n_action_steps=cfg.n_action_steps + cfg.n_latency_steps,
policy_cls = DiffusionPolicy
policy_kwargs = {
"cfg": cfg.policy,
"cfg_device": cfg.device,
"cfg_noise_scheduler": cfg.noise_scheduler,
"cfg_rgb_model": cfg.rgb_model,
"cfg_obs_encoder": cfg.obs_encoder,
"cfg_optimizer": cfg.optimizer,
"cfg_ema": cfg.ema,
"n_action_steps": cfg.n_action_steps + cfg.n_latency_steps,
**cfg.policy,
)
}
elif cfg.policy.name == "act":
from lerobot.common.policies.act.policy import ActionChunkingTransformerPolicy
policy = ActionChunkingTransformerPolicy(
cfg.policy, cfg.device, n_action_steps=cfg.n_action_steps + cfg.n_latency_steps
)
policy_cls = ActionChunkingTransformerPolicy
policy_kwargs = {"cfg": cfg.policy, "device": cfg.device, "n_action_steps": cfg.n_action_steps + cfg.n_latency_steps}
else:
raise ValueError(cfg.policy.name)
if cfg.policy.pretrained_model_path:
# policy.load(cfg.policy.pretrained_model_path, device=cfg.device)
policy = policy_cls.from_pretrained(cfg.policy.pretrained_model_path, map_location=cfg.device, **policy_kwargs)
# TODO(rcadene): hack for old pretrained models from fowm
if cfg.policy.name == "tdmpc" and "fowm" in cfg.policy.pretrained_model_path:
if "offline" in cfg.pretrained_model_path:
@@ -38,6 +56,5 @@ def make_policy(cfg):
policy.step[0] = 100000
else:
raise NotImplementedError()
policy.load(cfg.policy.pretrained_model_path)
return policy
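
Not part of the diff: a sketch of how the refactored factory is typically driven, assuming the `lerobot/configs/default.yaml` layout and the `env`/`policy` override names used elsewhere in this PR:
```python
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import init_hydra_config

# Build a Hydra config the same way the new examples do.
cfg = init_hydra_config(
    "lerobot/configs/default.yaml",
    overrides=["env=pusht", "policy=diffusion", "device=cpu"],
)

# make_policy dispatches on cfg.policy.name and, when cfg.policy.pretrained_model_path
# is set, loads the weights via from_pretrained.
policy = make_policy(cfg)
```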


@@ -122,9 +122,9 @@ class TDMPCPolicy(AbstractPolicy):
"""Save state dict of TOLD model to filepath."""
torch.save(self.state_dict(), fp)
def load(self, fp):
def load(self, fp, device=None):
"""Load a saved state dict from filepath into current agent."""
d = torch.load(fp)
d = torch.load(fp, map_location=device)
self.model.load_state_dict(d["model"])
self.model_target.load_state_dict(d["model_target"])


@@ -1,9 +1,13 @@
import logging
import os.path as osp
import random
from datetime import datetime
from pathlib import Path
import hydra
import numpy as np
import torch
from omegaconf import DictConfig
def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device:
@@ -63,3 +67,31 @@ def format_big_number(num):
num /= divisor
return num
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
path2 = path2.absolute()
try:
return path1.relative_to(path2)
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig:
"""Initialize a Hydra config given only the path to the relevant config file.
For config resolution, it is assumed that the config file's parent is the Hydra config dir.
"""
# TODO(alexander-soare): Resolve configs without Hydra initialization.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Hydra needs a path relative to this file.
hydra.initialize(
str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent))
)
cfg = hydra.compose(Path(config_path).stem, overrides)
return cfg
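
Not part of the diff: `_relative_path_between` above falls back to walking up with `..` when `path1` is not inside `path2`. A small illustration with made-up absolute paths:
```python
from pathlib import Path

from lerobot.common.utils import _relative_path_between

# path1 is not a subpath of path2, so relative_to() raises and the fallback kicks in:
# it finds the common prefix and prepends one ".." per extra component of path2.
print(_relative_path_between(Path("/repo/lerobot/configs"), Path("/repo/lerobot/scripts")))
# -> ../configs
```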


@@ -20,8 +20,6 @@ env:
action_repeat: 2
episode_length: 25
fps: ${fps}
visualization_width: 400
visualization_height: 400
policy:
state_dim: 4


@@ -30,14 +30,13 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10
import argparse
import json
import logging
import os.path as osp
import threading
import time
from typing import Tuple, Union
from datetime import datetime as dt
from pathlib import Path
import einops
import hydra
import imageio
import numpy as np
import torch
@@ -49,11 +48,10 @@ from torchrl.envs.batched_envs import BatchedEnvBase
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.simxarm.env import SimxarmEnv
from lerobot.common.logger import log_output_dir
from lerobot.common.policies.abstract import AbstractPolicy
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed
from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed
def write_video(video_path, stacked_frames, fps):
@@ -69,7 +67,19 @@ def eval_policy(
video_dir: Path = None,
fps: int = 15,
return_first_video: bool = False,
):
) -> Union[dict, Tuple[dict, torch.Tensor]]:
""" Evaluate a policy on an environment by running rollouts and computing metrics.
Args:
env: The environment to evaluate.
policy: The policy to evaluate.
num_episodes: The number of episodes to evaluate.
max_steps: The maximum number of steps per episode.
save_video: Whether to save videos of the evaluation episodes.
video_dir: The directory to save the videos.
fps: The frames per second for the videos.
return_first_video: Whether to return the first video as a tensor.
"""
if policy is not None:
policy.eval()
start = time.time()
@@ -87,10 +97,7 @@ def eval_policy(
def maybe_render_frame(env: EnvBase, _):
if save_video or (return_first_video and i == 0): # noqa: B023
# HACK
# TODO(aliberts): set render_mode for all envs
render_mode = "visualization" if isinstance(env, SimxarmEnv) else "rgb_array"
ep_frames.append(env.render(mode=render_mode)) # noqa: B023
ep_frames.append(env.render()) # noqa: B023
# Clear the policy's action queue before the start of a new rollout.
if policy is not None:
@@ -151,7 +158,7 @@ def eval_policy(
for thread in threads:
thread.join()
info = {
info = { # TODO: change to dataclass
"per_episode": [
{
"episode_ix": i,
@@ -184,6 +191,13 @@ def eval_policy(
def eval(cfg: dict, out_dir=None, stats_path=None):
""" Evaluate a policy.
Args:
cfg: The configuration (DictConfig).
out_dir: The directory to save the evaluation results (JSON file and videos)
stats_path: The path to the stats file.
"""
if out_dir is None:
raise NotImplementedError()
@@ -199,6 +213,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
log_output_dir(out_dir)
logging.info("Making transforms.")
# TODO(alexander-soare): Completely decouple datasets from evaluation.
offline_buffer = make_offline_buffer(cfg, stats_path=stats_path)
logging.info("Making environment.")
@@ -233,19 +248,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None):
logging.info("End of eval")
def _relative_path_between(path1: Path, path2: Path) -> Path:
"""Returns path1 relative to path2."""
path1 = path1.absolute()
path2 = path2.absolute()
try:
return path1.relative_to(path2)
except ValueError: # most likely because path1 is not a subpath of path2
common_parts = Path(osp.commonpath([path1, path2])).parts
return Path(
"/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :]))
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
@@ -263,19 +265,15 @@ if __name__ == "__main__":
if args.config is not None:
# Note: For the config_path, Hydra wants a path relative to this script file.
hydra.initialize(
config_path=str(
_relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent)
)
)
cfg = hydra.compose(Path(args.config).stem, args.overrides)
cfg = init_hydra_config(args.config, args.overrides)
# TODO(alexander-soare): Save and load stats in trained model directory.
stats_path = None
elif args.hub_id is not None:
folder = Path(snapshot_download(args.hub_id, revision=args.revision))
cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent)))
cfg = hydra.compose("config", args.overrides)
cfg.policy.pretrained_model_path = folder / "model.pt"
cfg = init_hydra_config(
folder / "config.yaml", [*args.overrides]
# folder / "config.yaml" # , [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides]
)
stats_path = folder / "stats.pth"
eval(


@@ -2,8 +2,9 @@ import pytest
import torch
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, init_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
@@ -18,7 +19,10 @@ from .utils import DEVICE, init_config
],
)
def test_factory(env_name, dataset_id):
cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"])
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]
)
offline_buffer = make_offline_buffer(cfg)
for key in offline_buffer.image_keys:
img = offline_buffer[0].get(key)


@@ -4,12 +4,13 @@ import torch
from torchrl.envs.utils import check_env_specs, step_mdp
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.pusht.env import PushtEnv
from lerobot.common.envs.simxarm.env import SimxarmEnv
from lerobot.common.envs.aloha.env import AlohaEnv
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, init_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
def print_spec_rollout(env):
@@ -110,7 +111,10 @@ def test_pusht(from_pixels, pixels_only):
],
)
def test_factory(env_name):
cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"])
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[f"env={env_name}", f"device={DEVICE}"],
)
offline_buffer = make_offline_buffer(cfg)


@@ -1,19 +1,70 @@
import pytest
from pathlib import Path
@pytest.mark.parametrize(
"path",
[
"examples/1_visualize_dataset.py",
"examples/2_evaluate_pretrained_policy.py",
"examples/3_train_policy.py",
],
)
def test_example(path):
with open(path, 'r') as file:
def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str:
for f, r in zip(finds, replaces):
assert f in text
text = text.replace(f, r)
return text
def test_example_1():
path = "examples/1_visualize_dataset.py"
with open(path, "r") as file:
file_contents = file.read()
exec(file_contents)
if path == "examples/1_visualize_dataset.py":
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists()
def test_examples_3_and_2():
"""
Train a model with example 3, check the outputs.
Evaluate the trained model with example 2, check the outputs.
"""
path = "examples/3_train_policy.py"
with open(path, "r") as file:
file_contents = file.read()
# Do fewer steps and use CPU.
file_contents = _find_and_replace(
file_contents,
['"offline_steps=5000"', '"device=cuda"'],
['"offline_steps=1"', '"device=cpu"'],
)
exec(file_contents)
for file_name in ["model.pt", "stats.pth", "config.yaml"]:
assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists()
path = "examples/2_evaluate_pretrained_policy.py"
with open(path, "r") as file:
file_contents = file.read()
# Do fewer evals, use CPU, and use the local model.
file_contents = _find_and_replace(
file_contents,
[
'"eval_episodes=10"',
'"rollout_batch_size=10"',
'"device=cuda"',
'# folder = Path("outputs/train/example_pusht_diffusion")',
'hub_id = "lerobot/diffusion_policy_pusht_image"',
"folder = Path(snapshot_download(hub_id)",
],
[
'"eval_episodes=1"',
'"rollout_batch_size=1"',
'"device=cpu"',
'folder = Path("outputs/train/example_pusht_diffusion")',
"",
"",
],
)
assert Path(f"outputs/train/example_pusht_diffusion").exists()


@@ -1,4 +1,3 @@
from omegaconf import open_dict
import pytest
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
@@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_offline_buffer
from lerobot.common.policies.abstract import AbstractPolicy
from .utils import DEVICE, init_config
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
@@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides):
- Updating the policy.
- Using the policy to select actions at inference time.
"""
cfg = init_config(
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[
f"env={env_name}",
f"policy={policy_name}",


@@ -1,13 +1,6 @@
import os
import hydra
from hydra import compose, initialize
CONFIG_PATH = "../lerobot/configs"
# Pass this as the first argument to init_hydra_config.
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
def init_config(config_name="default", overrides=None):
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path=CONFIG_PATH)
cfg = compose(config_name=config_name, overrides=overrides)
return cfg