From 7ad1909641580cef1661fa3d579aa29de4ec9c0d Mon Sep 17 00:00:00 2001 From: Simon Alibert <75076266+aliberts@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:47:42 +0200 Subject: [PATCH] Tests cleaning & simplification (#81) --- .github/ISSUE_TEMPLATE/bug-report.yml | 2 +- .github/workflows/test.yml | 4 +- .pre-commit-config.yaml | 2 +- CONTRIBUTING.md | 20 ++++ examples/3_evaluate_pretrained_policy.py | 2 +- examples/4_train_policy.py | 2 +- lerobot/__init__.py | 37 ++++-- lerobot/common/datasets/aloha.py | 1 + lerobot/common/datasets/pusht.py | 1 + lerobot/common/datasets/xarm.py | 5 +- lerobot/common/policies/factory.py | 2 +- lerobot/common/policies/tdmpc/policy.py | 2 +- lerobot/common/utils/import_utils.py | 44 +++++++ lerobot/common/{ => utils}/utils.py | 0 .../env.py => scripts/display_sys_info.py} | 4 +- lerobot/scripts/eval.py | 2 +- lerobot/scripts/train.py | 2 +- lerobot/scripts/visualize_dataset.py | 2 +- tests/test_available.py | 81 +++++++------ tests/test_datasets.py | 111 ++++++++++-------- tests/test_envs.py | 50 +++----- tests/test_examples.py | 10 +- tests/test_policies.py | 13 +- tests/utils.py | 35 +++++- 24 files changed, 277 insertions(+), 157 deletions(-) create mode 100644 lerobot/common/utils/import_utils.py rename lerobot/common/{ => utils}/utils.py (100%) rename lerobot/{commands/env.py => scripts/display_sys_info.py} (96%) diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 132c21c..7cbed67 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -11,7 +11,7 @@ body: id: system-info attributes: label: System Info - description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.commands.env` and copy-pasting its outputs below + description: If needed, you can share your lerobot configuration with us by running `python -m lerobot.scripts.display_sys_info` and copy-pasting its outputs below render: Shell placeholder: lerobot version, OS, python version, numpy version, torch version, and lerobot's configuration validations: diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index a86193b..f0a7e78 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -117,11 +117,9 @@ jobs: # run tests & coverage #---------------------------------------------- - name: Run tests - env: - LEROBOT_TESTS_DEVICE: cpu run: | source .venv/bin/activate - pytest --cov=./lerobot --cov-report=xml tests + pytest -v --cov=./lerobot --cov-report=xml tests # TODO(aliberts): Link with HF Codecov account # - name: Upload coverage reports to Codecov with GitHub Action diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d0fb55..4a28292 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,4 +1,4 @@ -exclude: ^(data/|tests/) +exclude: ^(data/|tests/data) default_language_version: python: python3.10 repos: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0b40d81..5b69a13 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -65,6 +65,26 @@ A good feature request addresses the following points: If your issue is well written we're already 80% of the way there by the time you post it. 
+## Adding new policies, datasets or environments
+
+Look at our implementations for [datasets](./lerobot/common/datasets/), [policies](./lerobot/common/policies/),
+environments ([aloha](https://github.com/huggingface/gym-aloha),
+[xarm](https://github.com/huggingface/gym-xarm),
+[pusht](https://github.com/huggingface/gym-pusht))
+and follow the same API design.
+
+When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
+- Update `available_datasets` in `lerobot/__init__.py`
+- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
+
+When implementing a new environment (e.g. `gym_aloha`), follow these steps:
+- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py`
+
+When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
+- Update `available_policies` in `lerobot/__init__.py`
+- Set the required `name` class attribute.
+- Update variables in `tests/test_available.py` by importing your new Policy class
+
 ## Submitting a pull request (PR)
 
 Before writing code, we strongly advise you to search through the existing PRs or
diff --git a/examples/3_evaluate_pretrained_policy.py b/examples/3_evaluate_pretrained_policy.py
index b3d13f7..a892fa2 100644
--- a/examples/3_evaluate_pretrained_policy.py
+++ b/examples/3_evaluate_pretrained_policy.py
@@ -7,7 +7,7 @@ from pathlib import Path
 
 from huggingface_hub import snapshot_download
 
-from lerobot.common.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config
 from lerobot.scripts.eval import eval
 
 # Get a pretrained policy from the hub.
diff --git a/examples/4_train_policy.py b/examples/4_train_policy.py
index 7a7a7aa..1ccb40d 100644
--- a/examples/4_train_policy.py
+++ b/examples/4_train_policy.py
@@ -13,7 +13,7 @@ from omegaconf import OmegaConf
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
-from lerobot.common.utils import init_hydra_config
+from lerobot.common.utils.utils import init_hydra_config
 
 output_directory = Path("outputs/train/example_pusht_diffusion")
 os.makedirs(output_directory, exist_ok=True)
diff --git a/lerobot/__init__.py b/lerobot/__init__.py
index 8ab95df..83e51c7 100644
--- a/lerobot/__init__.py
+++ b/lerobot/__init__.py
@@ -7,16 +7,22 @@ Example:
     import lerobot
     print(lerobot.available_envs)
     print(lerobot.available_tasks_per_env)
-    print(lerobot.available_datasets_per_env)
     print(lerobot.available_datasets)
     print(lerobot.available_policies)
+    print(lerobot.available_policies_per_env)
 ```
 
-When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
-- Set the required class attributes: `available_datasets`.
-- Set the required class attributes: `name`.
-- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-- Update variables in `tests/test_available.py` by importing your new class
+When implementing a new dataset class (e.g. `AlohaDataset`) follow these steps:
+- Update `available_datasets` in `lerobot/__init__.py`
+- Set the required `available_datasets` class attribute using the previously updated `lerobot.available_datasets`
+
+When implementing a new environment (e.g. 
`gym_aloha`), follow these steps: +- Update `available_envs`, `available_tasks_per_env` and `available_datasets` in `lerobot/__init__.py` + +When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps: +- Update `available_policies` in `lerobot/__init__.py` +- Set the required `name` class attribute. +- Update variables in `tests/test_available.py` by importing your new Policy class """ from lerobot.__version__ import __version__ # noqa: F401 @@ -36,7 +42,7 @@ available_tasks_per_env = { "xarm": ["XarmLift-v0"], } -available_datasets_per_env = { +available_datasets = { "aloha": [ "aloha_sim_insertion_human", "aloha_sim_insertion_scripted", @@ -47,10 +53,23 @@ available_datasets_per_env = { "xarm": ["xarm_lift_medium"], } -available_datasets = [dataset for env in available_envs for dataset in available_datasets_per_env[env]] - available_policies = [ "act", "diffusion", "tdmpc", ] + +available_policies_per_env = { + "aloha": ["act"], + "pusht": ["diffusion"], + "xarm": ["tdmpc"], +} + +env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] +env_dataset_pairs = [(env, dataset) for env, datasets in available_datasets.items() for dataset in datasets] +env_dataset_policy_triplets = [ + (env, dataset, policy) + for env, datasets in available_datasets.items() + for dataset in datasets + for policy in available_policies_per_env[env] +] diff --git a/lerobot/common/datasets/aloha.py b/lerobot/common/datasets/aloha.py index 87ee57a..4769a2b 100644 --- a/lerobot/common/datasets/aloha.py +++ b/lerobot/common/datasets/aloha.py @@ -14,6 +14,7 @@ class AlohaDataset(torch.utils.data.Dataset): https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_scripted """ + # Copied from lerobot/__init__.py available_datasets = [ "aloha_sim_insertion_human", "aloha_sim_insertion_scripted", diff --git a/lerobot/common/datasets/pusht.py b/lerobot/common/datasets/pusht.py index b9d06ba..c5c06bf 100644 --- a/lerobot/common/datasets/pusht.py +++ b/lerobot/common/datasets/pusht.py @@ -17,6 +17,7 @@ class PushtDataset(torch.utils.data.Dataset): If `None`, no shift is applied to current timestamp and the data from the current frame is loaded. 
""" + # Copied from lerobot/__init__.py available_datasets = ["pusht"] fps = 10 image_keys = ["observation.image"] diff --git a/lerobot/common/datasets/xarm.py b/lerobot/common/datasets/xarm.py index 28ef4fa..711ff64 100644 --- a/lerobot/common/datasets/xarm.py +++ b/lerobot/common/datasets/xarm.py @@ -11,9 +11,8 @@ class XarmDataset(torch.utils.data.Dataset): https://huggingface.co/datasets/lerobot/xarm_lift_medium """ - available_datasets = [ - "xarm_lift_medium", - ] + # Copied from lerobot/__init__.py + available_datasets = ["xarm_lift_medium"] fps = 15 image_keys = ["observation.image"] diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index b5b5f86..9698175 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -2,7 +2,7 @@ import inspect from omegaconf import DictConfig, OmegaConf -from lerobot.common.utils import get_safe_torch_device +from lerobot.common.utils.utils import get_safe_torch_device def _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg): diff --git a/lerobot/common/policies/tdmpc/policy.py b/lerobot/common/policies/tdmpc/policy.py index ed28c4a..adaa30c 100644 --- a/lerobot/common/policies/tdmpc/policy.py +++ b/lerobot/common/policies/tdmpc/policy.py @@ -11,7 +11,7 @@ import torch.nn as nn import lerobot.common.policies.tdmpc.helper as h from lerobot.common.policies.utils import populate_queues -from lerobot.common.utils import get_safe_torch_device +from lerobot.common.utils.utils import get_safe_torch_device FIRST_FRAME = 0 diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py new file mode 100644 index 0000000..642e0ff --- /dev/null +++ b/lerobot/common/utils/import_utils.py @@ -0,0 +1,44 @@ +import importlib +import logging + + +def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: + """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py + Check if the package spec exists and grab its version to avoid importing a local directory. + **Note:** this doesn't work for all packages. 
+ """ + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + logging.debug(f"Detected {pkg_name} version: {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +_torch_available, _torch_version = is_package_available("torch", return_version=True) +_gym_xarm_available = is_package_available("gym_xarm") +_gym_aloha_available = is_package_available("gym_aloha") +_gym_pusht_available = is_package_available("gym_pusht") diff --git a/lerobot/common/utils.py b/lerobot/common/utils/utils.py similarity index 100% rename from lerobot/common/utils.py rename to lerobot/common/utils/utils.py diff --git a/lerobot/commands/env.py b/lerobot/scripts/display_sys_info.py similarity index 96% rename from lerobot/commands/env.py rename to lerobot/scripts/display_sys_info.py index 1a7e950..e4ea426 100644 --- a/lerobot/commands/env.py +++ b/lerobot/scripts/display_sys_info.py @@ -15,7 +15,7 @@ cuda_version = torch._C._cuda_getCompiledVersion() if torch.version.cuda is not # TODO(aliberts): refactor into an actual command `lerobot env` -def get_env_info() -> dict: +def display_sys_info() -> dict: """Run this to get basic system info to help for tracking issues & bugs.""" info = { "`lerobot` version": version, @@ -40,4 +40,4 @@ def format_dict(d: dict) -> str: if __name__ == "__main__": - get_env_info() + display_sys_info() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index bf53322..c76d4b4 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -50,7 +50,7 @@ from lerobot.common.envs.factory import make_env from lerobot.common.envs.utils import postprocess_action, preprocess_observation from lerobot.common.logger import log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed +from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, init_logging, set_global_seed def write_video(video_path, stacked_frames, fps): diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index b1d6306..601f730 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -13,7 +13,7 @@ from lerobot.common.datasets.utils import cycle from lerobot.common.envs.factory import make_env from lerobot.common.logger import Logger, log_output_dir from lerobot.common.policies.factory import make_policy -from lerobot.common.utils import ( +from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, init_logging, diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py index 739115e..226fdc1 100644 --- a/lerobot/scripts/visualize_dataset.py +++ 
b/lerobot/scripts/visualize_dataset.py
@@ -9,7 +9,7 @@ import torch
 
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.logger import log_output_dir
-from lerobot.common.utils import init_logging
+from lerobot.common.utils.utils import init_logging
 
 NUM_EPISODES_TO_RENDER = 50
 MAX_NUM_STEPS = 1000
diff --git a/tests/test_available.py b/tests/test_available.py
index 373cc1a..4328ec6 100644
--- a/tests/test_available.py
+++ b/tests/test_available.py
@@ -1,53 +1,60 @@
-"""
-This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be sucessfully
-imported and that their class attributes (eg. `available_datasets`, `name`, `available_tasks`) are valid.
-
-When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
-- Set the required class attributes: `available_datasets`.
-- Set the required class attributes: `name`.
-- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
-- Update variables in `tests/test_available.py` by importing your new class
-"""
-
 import importlib
-import pytest
-import lerobot
-import gymnasium as gym
-from lerobot.common.datasets.xarm import XarmDataset
+import gymnasium as gym
+import pytest
+
+import lerobot
 from lerobot.common.datasets.aloha import AlohaDataset
 from lerobot.common.datasets.pusht import PushtDataset
-
+from lerobot.common.datasets.xarm import XarmDataset
 from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
 from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
 from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
+from tests.utils import require_env
 
 
-def test_available():
+@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
+@require_env
+def test_available_env_task(env_name: str, task_name: str):
+    """
+    This test verifies that all environments listed in `lerobot/__init__.py` can
+    be successfully imported — if they're installed — and that their
+    `available_tasks_per_env` are valid.
+    """
+    package_name = f"gym_{env_name}"
+    importlib.import_module(package_name)
+    gym_handle = f"{package_name}/{task_name}"
+    assert gym_handle in gym.envs.registry, gym_handle
+
+
+@pytest.mark.parametrize(
+    "env_name, dataset_class",
+    [
+        ("aloha", AlohaDataset),
+        ("pusht", PushtDataset),
+        ("xarm", XarmDataset),
+    ],
+)
+def test_available_datasets(env_name, dataset_class):
+    """
+    This test verifies that the class attribute `available_datasets` for all
+    dataset classes is consistent with those listed in `lerobot/__init__.py`.
+    """
+    available_env_datasets = lerobot.available_datasets[env_name]
+    assert set(available_env_datasets) == set(
+        dataset_class.available_datasets
+    ), f"{env_name=} {available_env_datasets=}"
+
+
+def test_available_policies():
+    """
+    This test verifies that the class attribute `name` for all policies is
+    consistent with those listed in `lerobot/__init__.py`. 
+ """ policy_classes = [ ActionChunkingTransformerPolicy, DiffusionPolicy, TDMPCPolicy, ] - - dataset_class_per_env = { - "aloha": AlohaDataset, - "pusht": PushtDataset, - "xarm": XarmDataset, - } - policies = [pol_cls.name for pol_cls in policy_classes] assert set(policies) == set(lerobot.available_policies), policies - - for env_name in lerobot.available_envs: - for task_name in lerobot.available_tasks_per_env[env_name]: - package_name = f"gym_{env_name}" - importlib.import_module(package_name) - gym_handle = f"{package_name}/{task_name}" - assert gym_handle in gym.envs.registry.keys(), gym_handle - - dataset_class = dataset_class_per_env[env_name] - available_datasets = lerobot.available_datasets_per_env[env_name] - assert set(available_datasets) == set(dataset_class.available_datasets), f"{env_name=} {available_datasets=}" - - diff --git a/tests/test_datasets.py b/tests/test_datasets.py index f22eb99..e488c30 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -1,33 +1,35 @@ +import logging import os from pathlib import Path + import einops import pytest import torch - -from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_previous_and_future_frames -from lerobot.common.transforms import Prod -from lerobot.common.utils import init_hydra_config -import logging -from lerobot.common.datasets.factory import make_dataset from datasets import Dataset -from .utils import DEVICE, DEFAULT_CONFIG_PATH - -@pytest.mark.parametrize( - "env_name,dataset_id,policy_name", - [ - ("xarm", "xarm_lift_medium", "tdmpc"), - ("pusht", "pusht", "diffusion"), - ("aloha", "aloha_sim_insertion_human", "act"), - ("aloha", "aloha_sim_insertion_scripted", "act"), - ("aloha", "aloha_sim_transfer_cube_human", "act"), - ("aloha", "aloha_sim_transfer_cube_scripted", "act"), - ], +import lerobot +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.utils import ( + compute_stats, + get_stats_einops_patterns, + load_previous_and_future_frames, ) +from lerobot.common.transforms import Prod +from lerobot.common.utils.utils import init_hydra_config + +from .utils import DEFAULT_CONFIG_PATH, DEVICE + + +@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets) def test_factory(env_name, dataset_id, policy_name): cfg = init_hydra_config( DEFAULT_CONFIG_PATH, - overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"] + overrides=[ + f"env={env_name}", + f"dataset_id={dataset_id}", + f"policy={policy_name}", + f"device={DEVICE}", + ], ) dataset = make_dataset(cfg) delta_timestamps = dataset.delta_timestamps @@ -51,7 +53,7 @@ def test_factory(env_name, dataset_id, policy_name): (key, 3, True), ) assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}" - + # test number of dimensions for key, ndim, required in keys_ndim_required: if key not in item: @@ -60,13 +62,13 @@ def test_factory(env_name, dataset_id, policy_name): else: logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.') continue - + if delta_timestamps is not None and key in delta_timestamps: assert item[key].ndim == ndim + 1, f"{key}" assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}" else: assert item[key].ndim == ndim, f"{key}" - + if key in image_keys: assert item[key].dtype == torch.float32, f"{key}" # TODO(rcadene): we assume for now that image normalization takes place in the model @@ -77,17 +79,16 @@ def test_factory(env_name, dataset_id, policy_name): # 
test t,c,h,w assert item[key].shape[1] == 3, f"{key}" else: - # test c,h,w + # test c,h,w assert item[key].shape[0] == 3, f"{key}" - if delta_timestamps is not None: # test missing keys in delta_timestamps for key in delta_timestamps: assert key in item, f"{key}" -def test_compute_stats(): +def test_compute_stats_on_xarm(): """Check that the statistics are computed correctly according to the stats_patterns property. We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do @@ -95,20 +96,20 @@ def test_compute_stats(): """ from lerobot.common.datasets.xarm import XarmDataset - DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None + data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None # get transform to convert images from uint8 [0,255] to float32 [0,1] transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0) dataset = XarmDataset( dataset_id="xarm_lift_medium", - root=DATA_DIR, + root=data_dir, transform=transform, ) # Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched # computation of the statistics. While doing this, we also make sure it works when we don't divide the - # dataset into even batches. + # dataset into even batches. computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25)) # get einops patterns to aggregate batches and compute statistics @@ -128,7 +129,9 @@ def test_compute_stats(): for k, pattern in stats_patterns.items(): expected_stats[k] = {} expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean") - expected_stats[k]["std"] = torch.sqrt(einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")) + expected_stats[k]["std"] = torch.sqrt( + einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean") + ) expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min") expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max") @@ -153,12 +156,14 @@ def test_compute_stats(): def test_load_previous_and_future_frames_within_tolerance(): - hf_dataset = Dataset.from_dict({ - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], - "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], - }) + hf_dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], + "index": [0, 1, 2, 3, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [5, 5, 5, 5, 5], + } + ) hf_dataset = hf_dataset.with_format("torch") item = hf_dataset[2] delta_timestamps = {"index": [-0.2, 0, 0.139]} @@ -168,13 +173,16 @@ def test_load_previous_and_future_frames_within_tolerance(): assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values" assert not is_pad.any(), "Unexpected padding detected" + def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(): - hf_dataset = Dataset.from_dict({ - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], - "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], - }) + hf_dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], + "index": [0, 1, 2, 3, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [5, 5, 5, 5, 5], + } + ) hf_dataset = hf_dataset.with_format("torch") item = hf_dataset[2] delta_timestamps = {"index": [-0.2, 0, 0.141]} @@ -182,13 +190,16 @@ def 
test_load_previous_and_future_frames_outside_tolerance_inside_episode_range( with pytest.raises(AssertionError): load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) + def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range(): - hf_dataset = Dataset.from_dict({ - "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], - "index": [0, 1, 2, 3, 4], - "episode_data_index_from": [0, 0, 0, 0, 0], - "episode_data_index_to": [5, 5, 5, 5, 5], - }) + hf_dataset = Dataset.from_dict( + { + "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5], + "index": [0, 1, 2, 3, 4], + "episode_data_index_from": [0, 0, 0, 0, 0], + "episode_data_index_to": [5, 5, 5, 5, 5], + } + ) hf_dataset = hf_dataset.with_format("torch") item = hf_dataset[2] delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]} @@ -196,6 +207,6 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol) data, is_pad = item["index"], item["index_is_pad"] assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values" - assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values" - - + assert torch.equal( + is_pad, torch.tensor([True, False, False, True, True]) + ), "Padding does not match expected values" diff --git a/tests/test_envs.py b/tests/test_envs.py index d25231b..33928a6 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,49 +1,37 @@ import importlib + +import gymnasium as gym import pytest import torch -from lerobot.common.datasets.factory import make_dataset -import gymnasium as gym from gymnasium.utils.env_checker import check_env +import lerobot +from lerobot.common.datasets.factory import make_dataset from lerobot.common.envs.factory import make_env -from lerobot.common.utils import init_hydra_config - from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.utils.utils import init_hydra_config -from .utils import DEVICE, DEFAULT_CONFIG_PATH +from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env + +OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] -@pytest.mark.parametrize( - "env_name, task, obs_type", - [ - # ("AlohaInsertion-v0", "state"), - ("aloha", "AlohaInsertion-v0", "pixels"), - ("aloha", "AlohaInsertion-v0", "pixels_agent_pos"), - ("aloha", "AlohaTransferCube-v0", "pixels"), - ("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"), - ("xarm", "XarmLift-v0", "state"), - ("xarm", "XarmLift-v0", "pixels"), - ("xarm", "XarmLift-v0", "pixels_agent_pos"), - ("pusht", "PushT-v0", "state"), - ("pusht", "PushT-v0", "pixels"), - ("pusht", "PushT-v0", "pixels_agent_pos"), - ], -) -def test_env(env_name, task, obs_type): +@pytest.mark.parametrize("obs_type", OBS_TYPES) +@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs) +@require_env +def test_env(env_name, env_task, obs_type): + if env_name == "aloha" and obs_type == "state": + pytest.skip("`state` observations not available for aloha") + package_name = f"gym_{env_name}" importlib.import_module(package_name) - env = gym.make(f"{package_name}/{task}", obs_type=obs_type) + env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type) check_env(env.unwrapped, skip_render_check=True) env.close() -@pytest.mark.parametrize( - "env_name", - [ - "pusht", - "xarm", - "aloha", - ], -) + +@pytest.mark.parametrize("env_name", lerobot.available_envs) +@require_env def test_factory(env_name): cfg = init_hydra_config( 
        DEFAULT_CONFIG_PATH,
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 0ca1f21..a3f90cf 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -1,5 +1,5 @@
-from pathlib import Path
 import subprocess
+from pathlib import Path
 
 
 def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
@@ -10,7 +10,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
 
 
 def _run_script(path):
-    subprocess.run(['python', path], check=True)
+    subprocess.run(["python", path], check=True)
 
 
 def test_example_1():
@@ -33,7 +33,7 @@ def test_examples_4_and_3():
 
     path = "examples/4_train_policy.py"
 
-    with open(path, "r") as file:
+    with open(path) as file:
         file_contents = file.read()
 
     # Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
@@ -55,7 +55,7 @@ def test_examples_4_and_3():
 
     path = "examples/3_evaluate_pretrained_policy.py"
 
-    with open(path, "r") as file:
+    with open(path) as file:
         file_contents = file.read()
 
     # Do less evals, use CPU, and use the local model.
@@ -74,4 +74,4 @@ def test_examples_4_and_3():
         ],
     )
 
-    assert Path(f"outputs/train/example_pusht_diffusion").exists()
+    assert Path("outputs/train/example_pusht_diffusion").exists()
diff --git a/tests/test_policies.py b/tests/test_policies.py
index f53e402..ab679fc 100644
--- a/tests/test_policies.py
+++ b/tests/test_policies.py
@@ -1,16 +1,18 @@
 import pytest
 import torch
 
+from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.utils import cycle
+from lerobot.common.envs.factory import make_env
 from lerobot.common.envs.utils import postprocess_action, preprocess_observation
 from lerobot.common.policies.factory import make_policy
 from lerobot.common.policies.policy_protocol import Policy
-from lerobot.common.envs.factory import make_env
-from lerobot.common.datasets.factory import make_dataset
-from lerobot.common.utils import init_hydra_config
-from .utils import DEVICE, DEFAULT_CONFIG_PATH
+from lerobot.common.utils.utils import init_hydra_config
+
+from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
 
 
+# TODO(aliberts): refactor using lerobot/__init__.py variables
 @pytest.mark.parametrize(
     "env_name,policy_name,extra_overrides",
     [
@@ -21,10 +23,9 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
         ("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
         ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
         ("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
-        # TODO(aliberts): xarm not working with diffusion
-        # ("xarm", "diffusion", []),
     ],
 )
+@require_env
 def test_policy(env_name, policy_name, extra_overrides):
     """
     Tests:
diff --git a/tests/utils.py b/tests/utils.py
index 6169c3b..f3fe579 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,6 +1,37 @@
-import os
+import pytest
+import torch
+
+from lerobot.common.utils.import_utils import is_package_available
 
 # Pass this as the first argument to init_hydra_config.
 DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
 
-DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def require_env(func):
+    """
+    Decorator that skips the test if the required environment package is not installed.
+    As it needs 'env_name' in its arguments, it also checks whether it is provided as an argument.
+    """
+    from functools import wraps
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        # Determine if 'env_name' is provided and extract its value
+        arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
+        if "env_name" in arg_names:
+            # Get the index of 'env_name' and retrieve the value from args
+            index = arg_names.index("env_name")
+            env_name = args[index] if len(args) > index else kwargs.get("env_name")
+        else:
+            raise ValueError("Function does not have 'env_name' as an argument.")
+
+        # Perform the package check
+        package_name = f"gym_{env_name}"
+        if not is_package_available(package_name):
+            pytest.skip(f"gym-{env_name} not installed")
+
+        return func(*args, **kwargs)
+
+    return wrapper