From 1ed0110900db5d8db8cbf7757705c65025a61321 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 26 Mar 2024 16:13:40 +0000 Subject: [PATCH 1/5] finish examples 2 and 3 --- examples/2_evaluate_pretrained_policy.py | 40 ++++++++++++++- examples/3_train_policy.py | 56 ++++++++++++++++++++- lerobot/common/datasets/abstract.py | 4 +- lerobot/common/datasets/aloha.py | 2 + lerobot/common/datasets/factory.py | 6 +++ lerobot/common/datasets/pusht.py | 2 + lerobot/common/datasets/simxarm.py | 2 + lerobot/common/utils.py | 30 +++++++++++ lerobot/scripts/eval.py | 33 +++---------- tests/test_examples.py | 63 +++++++++++++++++++----- 10 files changed, 196 insertions(+), 42 deletions(-) diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 4640904..bb73167 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -1 +1,39 @@ -# TODO +""" +This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local +training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first. +""" + +from pathlib import Path + +from huggingface_hub import snapshot_download + +from lerobot.common.utils import init_hydra_config +from lerobot.scripts.eval import eval + +# Get a pretrained policy from the hub. +hub_id = "lerobot/diffusion_policy_pusht_image" +folder = Path(snapshot_download(hub_id)) +# OR uncomment the following to evaluate a policy from the local outputs/train folder. +folder = Path("outputs/train/example_pusht_diffusion") + +config_path = folder / "config.yaml" +weights_path = folder / "model.pt" +stats_path = folder / "stats.pth" # normalization stats + +# Override some config parameters to do with evaluation. +overrides = [ + f"policy.pretrained_model_path={weights_path}", + "eval_episodes=10", + "rollout_batch_size=10", + "device=cuda", +] + +# Create a Hydra config. +cfg = init_hydra_config(config_path, overrides) + +# Evaluate the policy and save the outputs including metrics and videos. +eval( + cfg, + out_dir=f"outputs/eval/example_{cfg.env.name}_{cfg.policy.name}", + stats_path=stats_path, +) diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 4640904..01a4cf7 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -1 +1,55 @@ -# TODO +"""This scripts demonstrates how to train Diffusion Policy on the PushT environment. + +Once you have trained a model with this script, you can try to evaluate it on +examples/2_evaluate_pretrained_policy.py +""" + +import os +from pathlib import Path + +import torch +from omegaconf import OmegaConf +from tqdm import trange + +from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.policies.diffusion.policy import DiffusionPolicy +from lerobot.common.utils import init_hydra_config + +output_directory = Path("outputs/train/example_pusht_diffusion") +os.makedirs(output_directory, exist_ok=True) + +overrides = [ + "env=pusht", + "policy=diffusion", + # Adjust as you prefer. 5000 steps are needed to get something worth evaluating. + "offline_steps=5000", + "log_freq=250", + "device=cuda", +] + +cfg = init_hydra_config("lerobot/configs/default.yaml", overrides) + +policy = DiffusionPolicy( + cfg=cfg.policy, + cfg_device=cfg.device, + cfg_noise_scheduler=cfg.noise_scheduler, + cfg_rgb_model=cfg.rgb_model, + cfg_obs_encoder=cfg.obs_encoder, + cfg_optimizer=cfg.optimizer, + cfg_ema=cfg.ema, + n_action_steps=cfg.n_action_steps + cfg.n_latency_steps, + **cfg.policy, +) +policy.train() + +offline_buffer = make_offline_buffer(cfg) + +for offline_step in trange(cfg.offline_steps): + train_info = policy.update(offline_buffer, offline_step) + if offline_step % cfg.log_freq == 0: + print(train_info) + +# Save the policy, configuration, and normalization stats for later use. +policy.save(output_directory / "model.pt") +OmegaConf.save(cfg, output_directory / "config.yaml") +torch.save(offline_buffer.transform[-1].stats, output_directory / "stats.pth") diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index a81de49..c05d25c 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -59,6 +59,8 @@ class AbstractDataset(TensorDictReplayBuffer): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + # Don't actually load any data. This is a stand-in solution to get the transforms. + dummy: bool = False, ): assert ( self.available_datasets is not None @@ -77,7 +79,7 @@ class AbstractDataset(TensorDictReplayBuffer): f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." ) - storage = self._download_or_load_dataset() + storage = self._download_or_load_dataset() if not dummy else None super().__init__( storage=storage, diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 031c2cd..83d1581 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -97,6 +97,7 @@ class AlohaDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + dummy: bool = False, ): super().__init__( dataset_id, @@ -110,6 +111,7 @@ class AlohaDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, + dummy=dummy, ) @property diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4212e02..276dc76 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -21,7 +21,12 @@ def make_offline_buffer( overwrite_batch_size=None, overwrite_prefetch=None, stats_path=None, + # Don't actually load any data. This is a stand-in solution to get the transforms. + dummy=False, ): + if dummy and normalize and stats_path is None: + raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.") + if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 batch_size = None @@ -93,6 +98,7 @@ def make_offline_buffer( root=DATA_DIR, pin_memory=pin_memory, prefetch=prefetch if isinstance(prefetch, int) else None, + dummy=dummy, ) if cfg.policy.name == "tdmpc": diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index 624fb14..d167f3e 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -100,6 +100,7 @@ class PushtDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + dummy: bool = False, ): super().__init__( dataset_id, @@ -113,6 +114,7 @@ class PushtDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, + dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index dc30e69..06931d3 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -51,6 +51,7 @@ class SimxarmDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, + dummy: bool = False, ): super().__init__( dataset_id, @@ -64,6 +65,7 @@ class SimxarmDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, + dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 2af1d96..86383cd 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -1,9 +1,13 @@ import logging +import os.path as osp import random from datetime import datetime +from pathlib import Path +import hydra import numpy as np import torch +from omegaconf import DictConfig def get_safe_torch_device(cfg_device: str, log: bool = False) -> torch.device: @@ -63,3 +67,29 @@ def format_big_number(num): num /= divisor return num + + +def _relative_path_between(path1: Path, path2: Path) -> Path: + """Returns path1 relative to path2.""" + path1 = path1.absolute() + path2 = path2.absolute() + try: + return path1.relative_to(path2) + except ValueError: # most likely because path1 is not a subpath of path2 + common_parts = Path(osp.commonpath([path1, path2])).parts + return Path( + "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) + ) + + +def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig: + """Initialize a Hydra config given only the path to the relevant config file. + + For config resolution, it is assumed that the config file's parent is the Hydra config dir. + """ + # Hydra needs a path relative to this file. + hydra.initialize( + str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)) + ) + cfg = hydra.compose(Path(config_path).stem, overrides) + return cfg diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 1de0bb0..7251750 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -30,14 +30,12 @@ python lerobot/scripts/eval.py --hub-id HUB/ID --revision v1.0 eval_episodes=10 import argparse import json import logging -import os.path as osp import threading import time from datetime import datetime as dt from pathlib import Path import einops -import hydra import imageio import numpy as np import torch @@ -52,7 +50,7 @@ from lerobot.common.envs.factory import make_env from lerobot.common.logger import log_output_dir from lerobot.common.policies.abstract import AbstractPolicy from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import get_safe_torch_device, init_logging, set_global_seed +from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed def write_video(video_path, stacked_frames, fps): @@ -195,7 +193,8 @@ def eval(cfg: dict, out_dir=None, stats_path=None): log_output_dir(out_dir) logging.info("Making transforms.") - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) + # TODO(alexander-soare): Completely decouple datasets from evaluation. + offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None) logging.info("Making environment.") env = make_env(cfg, transform=offline_buffer.transform) @@ -229,19 +228,6 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("End of eval") -def _relative_path_between(path1: Path, path2: Path) -> Path: - """Returns path1 relative to path2.""" - path1 = path1.absolute() - path2 = path2.absolute() - try: - return path1.relative_to(path2) - except ValueError: # most likely because path1 is not a subpath of path2 - common_parts = Path(osp.commonpath([path1, path2])).parts - return Path( - "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) - ) - - if __name__ == "__main__": parser = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter @@ -259,19 +245,14 @@ if __name__ == "__main__": if args.config is not None: # Note: For the config_path, Hydra wants a path relative to this script file. - hydra.initialize( - config_path=str( - _relative_path_between(Path(args.config).absolute().parent, Path(__file__).parent) - ) - ) - cfg = hydra.compose(Path(args.config).stem, args.overrides) + cfg = init_hydra_config(args.config, args.overrides) # TODO(alexander-soare): Save and load stats in trained model directory. stats_path = None elif args.hub_id is not None: folder = Path(snapshot_download(args.hub_id, revision="v1.0")) - cfg = hydra.initialize(config_path=str(_relative_path_between(folder, Path(__file__).parent))) - cfg = hydra.compose("config", args.overrides) - cfg.policy.pretrained_model_path = folder / "model.pt" + cfg = init_hydra_config( + folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] + ) stats_path = folder / "stats.pth" eval( diff --git a/tests/test_examples.py b/tests/test_examples.py index 6c21eb4..9da7a66 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -1,19 +1,56 @@ -import pytest from pathlib import Path -@pytest.mark.parametrize( - "path", - [ - "examples/1_visualize_dataset.py", - "examples/2_evaluate_pretrained_policy.py", - "examples/3_train_policy.py", - ], -) -def test_example(path): - with open(path, 'r') as file: +def _find_and_replace(text: str, finds: list[str], replaces: list[str]) -> str: + for f, r in zip(finds, replaces): + assert f in text + text = text.replace(f, r) + return text + + +def test_example_1(): + path = "examples/1_visualize_dataset.py" + + with open(path, "r") as file: file_contents = file.read() exec(file_contents) - if path == "examples/1_visualize_dataset.py": - assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists() + assert Path("outputs/visualize_dataset/example/episode_0.mp4").exists() + + +def test_examples_3_and_2(): + """ + Train a model with example 3, check the outputs. + Evaluate the trained model with example 2, check the outputs. + """ + + path = "examples/3_train_policy.py" + + with open(path, "r") as file: + file_contents = file.read() + + # Do less steps and use CPU. + file_contents = _find_and_replace( + file_contents, + ['"offline_steps=5000"', '"device=cuda"'], + ['"offline_steps=1"', '"device=cpu"'], + ) + + exec(file_contents) + + for file_name in ["model.pt", "stats.pth", "config.yaml"]: + assert Path(f"outputs/train/example_pusht_diffusion/{file_name}").exists() + + path = "examples/2_evaluate_pretrained_policy.py" + + with open(path, "r") as file: + file_contents = file.read() + + # Do less evals and use CPU. + file_contents = _find_and_replace( + file_contents, + ['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'], + ['"eval_episodes=1"', '"rollout_batch_size=1"','"device=cpu"'], + ) + + assert Path(f"outputs/train/example_pusht_diffusion").exists() \ No newline at end of file From be4441c7ff423435abae2d08edb35a7c53df87f3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 26 Mar 2024 16:28:16 +0000 Subject: [PATCH 2/5] update README --- README.md | 22 ++++------------------ examples/2_evaluate_pretrained_policy.py | 2 +- tests/test_examples.py | 22 ++++++++++++++++++---- 3 files changed, 23 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index 3591186..d71fad7 100644 --- a/README.md +++ b/README.md @@ -135,11 +135,7 @@ hydra.run.dir=outputs/visualize_dataset/example ### Evaluate a pretrained policy -You can import our environment class, download pretrained policies from the HuggingFace hub, and use our rollout utilities with rendering: -```python -""" Copy pasted from `examples/2_evaluate_pretrained_policy.py` -# TODO -``` +Check out [example 2](./examples/2_evaluate_pretrained_policy.py) to see how you can load a pretrained policy from HuggingFace hub, load up the corresponding environment and model, and run an evaluation. Or you can achieve the same result by executing our script from the command line: ```bash @@ -150,7 +146,7 @@ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_hub ``` -After launching training of your own policy, you can also re-evaluate the checkpoints with: +After training your own policy, you can also re-evaluate the checkpoints with: ```bash python lerobot/scripts/eval.py \ --config PATH/TO/FOLDER/config.yaml \ @@ -163,19 +159,9 @@ See `python lerobot/scripts/eval.py --help` for more instructions. ### Train your own policy -You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): -```python -""" Copy pasted from `examples/3_train_policy.py` -# TODO -``` +You can import our dataset, environment, policy classes, and use our training utilities (if some data is missing, it will be automatically downloaded from HuggingFace hub): check out [example 3](./examples/3_train_policy.py). After you run this, you may want to revisit [example 2](./examples/2_evaluate_pretrained_policy.py) to evaluate your training output! -Or you can achieve the same result by executing our script from the command line: -```bash -python lerobot/scripts/train.py \ -hydra.run.dir=outputs/train/example -``` - -You can easily train any policy on any environment: +In general, you can use our training script to easily train any policy on any environment: ```bash python lerobot/scripts/train.py \ env=aloha \ diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index bb73167..be6abd1 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -14,7 +14,7 @@ from lerobot.scripts.eval import eval hub_id = "lerobot/diffusion_policy_pusht_image" folder = Path(snapshot_download(hub_id)) # OR uncomment the following to evaluate a policy from the local outputs/train folder. -folder = Path("outputs/train/example_pusht_diffusion") +# folder = Path("outputs/train/example_pusht_diffusion") config_path = folder / "config.yaml" weights_path = folder / "model.pt" diff --git a/tests/test_examples.py b/tests/test_examples.py index 9da7a66..4263e45 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -46,11 +46,25 @@ def test_examples_3_and_2(): with open(path, "r") as file: file_contents = file.read() - # Do less evals and use CPU. + # Do less evals, use CPU, and use the local model. file_contents = _find_and_replace( file_contents, - ['"eval_episodes=10"', '"rollout_batch_size=10"', '"device=cuda"'], - ['"eval_episodes=1"', '"rollout_batch_size=1"','"device=cpu"'], + [ + '"eval_episodes=10"', + '"rollout_batch_size=10"', + '"device=cuda"', + '# folder = Path("outputs/train/example_pusht_diffusion")', + 'hub_id = "lerobot/diffusion_policy_pusht_image"', + "folder = Path(snapshot_download(hub_id)", + ], + [ + '"eval_episodes=1"', + '"rollout_batch_size=1"', + '"device=cpu"', + 'folder = Path("outputs/train/example_pusht_diffusion")', + "", + "", + ], ) - assert Path(f"outputs/train/example_pusht_diffusion").exists() \ No newline at end of file + assert Path(f"outputs/train/example_pusht_diffusion").exists() From 011f2d27febf57686fe5143a12ff6798a40e38c4 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 26 Mar 2024 16:40:54 +0000 Subject: [PATCH 3/5] fix tests --- lerobot/common/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lerobot/common/utils.py b/lerobot/common/utils.py index 86383cd..7ed2933 100644 --- a/lerobot/common/utils.py +++ b/lerobot/common/utils.py @@ -87,6 +87,8 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D For config resolution, it is assumed that the config file's parent is the Hydra config dir. """ + # TODO(alexander-soare): Resolve configs without Hydra initialization. + hydra.core.global_hydra.GlobalHydra.instance().clear() # Hydra needs a path relative to this file. hydra.initialize( str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)) From 6cd671040fbff2c49778176aae263b27a4d943db Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 27 Mar 2024 13:22:14 +0000 Subject: [PATCH 4/5] fix revision --- README.md | 1 - lerobot/scripts/eval.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index d71fad7..0786c6d 100644 --- a/README.md +++ b/README.md @@ -141,7 +141,6 @@ Or you can achieve the same result by executing our script from the command line ```bash python lerobot/scripts/eval.py \ --hub-id lerobot/diffusion_policy_pusht_image \ ---revision v1.0 \ eval_episodes=10 \ hydra.run.dir=outputs/eval/example_hub ``` diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 7251750..2a3ab13 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -249,7 +249,7 @@ if __name__ == "__main__": # TODO(alexander-soare): Save and load stats in trained model directory. stats_path = None elif args.hub_id is not None: - folder = Path(snapshot_download(args.hub_id, revision="v1.0")) + folder = Path(snapshot_download(args.hub_id, revision=args.revision)) cfg = init_hydra_config( folder / "config.yaml", [f"policy.pretrained_model_path={folder / 'model.pt'}", *args.overrides] ) From b7c9c330725450d86ef24957c96d7710cf2edaee Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 27 Mar 2024 18:33:48 +0000 Subject: [PATCH 5/5] revision --- lerobot/common/datasets/abstract.py | 4 +--- lerobot/common/datasets/aloha.py | 2 -- lerobot/common/datasets/factory.py | 6 ------ lerobot/common/datasets/pusht.py | 2 -- lerobot/common/datasets/simxarm.py | 2 -- lerobot/scripts/eval.py | 2 +- tests/test_datasets.py | 8 ++++++-- tests/test_envs.py | 8 ++++++-- tests/test_policies.py | 8 ++++---- tests/utils.py | 11 ++--------- 10 files changed, 20 insertions(+), 33 deletions(-) diff --git a/lerobot/common/datasets/abstract.py b/lerobot/common/datasets/abstract.py index c05d25c..a81de49 100644 --- a/lerobot/common/datasets/abstract.py +++ b/lerobot/common/datasets/abstract.py @@ -59,8 +59,6 @@ class AbstractDataset(TensorDictReplayBuffer): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - # Don't actually load any data. This is a stand-in solution to get the transforms. - dummy: bool = False, ): assert ( self.available_datasets is not None @@ -79,7 +77,7 @@ class AbstractDataset(TensorDictReplayBuffer): f"The version of the dataset ({self.version}) is not enforced when root is provided ({self.root})." ) - storage = self._download_or_load_dataset() if not dummy else None + storage = self._download_or_load_dataset() super().__init__( storage=storage, diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 83d1581..031c2cd 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -97,7 +97,6 @@ class AlohaDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -111,7 +110,6 @@ class AlohaDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) @property diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 4e02f70..0407703 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -21,12 +21,7 @@ def make_offline_buffer( overwrite_batch_size=None, overwrite_prefetch=None, stats_path=None, - # Don't actually load any data. This is a stand-in solution to get the transforms. - dummy=False, ): - if dummy and normalize and stats_path is None: - raise ValueError("`stats_path` is required if `dummy` and `normalize` are True.") - if cfg.policy.balanced_sampling: assert cfg.online_steps > 0 batch_size = None @@ -93,7 +88,6 @@ def make_offline_buffer( root=DATA_DIR, pin_memory=pin_memory, prefetch=prefetch if isinstance(prefetch, int) else None, - dummy=dummy, ) if cfg.policy.name == "tdmpc": diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index d167f3e..624fb14 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -100,7 +100,6 @@ class PushtDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -114,7 +113,6 @@ class PushtDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/common/datasets/simxarm.py b/lerobot/common/datasets/simxarm.py index 06931d3..dc30e69 100644 --- a/lerobot/common/datasets/simxarm.py +++ b/lerobot/common/datasets/simxarm.py @@ -51,7 +51,6 @@ class SimxarmDataset(AbstractDataset): collate_fn: Callable | None = None, writer: Writer | None = None, transform: "torchrl.envs.Transform" = None, - dummy: bool = False, ): super().__init__( dataset_id, @@ -65,7 +64,6 @@ class SimxarmDataset(AbstractDataset): collate_fn=collate_fn, writer=writer, transform=transform, - dummy=dummy, ) def _download_and_preproc_obsolete(self): diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 2a3ab13..216769d 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -194,7 +194,7 @@ def eval(cfg: dict, out_dir=None, stats_path=None): logging.info("Making transforms.") # TODO(alexander-soare): Completely decouple datasets from evaluation. - offline_buffer = make_offline_buffer(cfg, stats_path=stats_path, dummy=stats_path is not None) + offline_buffer = make_offline_buffer(cfg, stats_path=stats_path) logging.info("Making environment.") env = make_env(cfg, transform=offline_buffer.transform) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 252e004..adaefcf 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -2,8 +2,9 @@ import pytest import torch from lerobot.common.datasets.factory import make_offline_buffer +from lerobot.common.utils import init_hydra_config -from .utils import DEVICE, init_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( @@ -18,7 +19,10 @@ from .utils import DEVICE, init_config ], ) def test_factory(env_name, dataset_id): - cfg = init_config(overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"]) + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[f"env={env_name}", f"env.task={dataset_id}", f"device={DEVICE}"] + ) offline_buffer = make_offline_buffer(cfg) for key in offline_buffer.image_keys: img = offline_buffer[0].get(key) diff --git a/tests/test_envs.py b/tests/test_envs.py index 2beafbd..2bd5e65 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,8 +7,9 @@ from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.envs.factory import make_env from lerobot.common.envs.pusht.env import PushtEnv from lerobot.common.envs.simxarm.env import SimxarmEnv +from lerobot.common.utils import init_hydra_config -from .utils import DEVICE, init_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH def print_spec_rollout(env): @@ -89,7 +90,10 @@ def test_pusht(from_pixels, pixels_only): ], ) def test_factory(env_name): - cfg = init_config(overrides=[f"env={env_name}", f"device={DEVICE}"]) + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, + overrides=[f"env={env_name}", f"device={DEVICE}"], + ) offline_buffer = make_offline_buffer(cfg) diff --git a/tests/test_policies.py b/tests/test_policies.py index d3dc0bc..5d6b46d 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -1,4 +1,3 @@ -from omegaconf import open_dict import pytest from tensordict import TensorDict from tensordict.nn import TensorDictModule @@ -10,8 +9,8 @@ from lerobot.common.policies.factory import make_policy from lerobot.common.envs.factory import make_env from lerobot.common.datasets.factory import make_offline_buffer from lerobot.common.policies.abstract import AbstractPolicy - -from .utils import DEVICE, init_config +from lerobot.common.utils import init_hydra_config +from .utils import DEVICE, DEFAULT_CONFIG_PATH @pytest.mark.parametrize( "env_name,policy_name,extra_overrides", @@ -34,7 +33,8 @@ def test_concrete_policy(env_name, policy_name, extra_overrides): - Updating the policy. - Using the policy to select actions at inference time. """ - cfg = init_config( + cfg = init_hydra_config( + DEFAULT_CONFIG_PATH, overrides=[ f"env={env_name}", f"policy={policy_name}", diff --git a/tests/utils.py b/tests/utils.py index 5570933..6169c3b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,13 +1,6 @@ import os -import hydra -from hydra import compose, initialize -CONFIG_PATH = "../lerobot/configs" +# Pass this as the first argument to init_hydra_config. +DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml" DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda") - -def init_config(config_name="default", overrides=None): - hydra.core.global_hydra.GlobalHydra.instance().clear() - initialize(config_path=CONFIG_PATH) - cfg = compose(config_name=config_name, overrides=overrides) - return cfg