Tests cleaning & simplification (#81)

This commit is contained in:
Simon Alibert
2024-04-18 14:47:42 +02:00
committed by GitHub
parent 0928afd37d
commit 7ad1909641
24 changed files with 277 additions and 157 deletions

View File

@@ -1,53 +1,60 @@
"""
This test verifies that all environments, datasets, policies listed in `lerobot/__init__.py` can be successfully
imported and that their class attributes (e.g. `available_datasets`, `name`, `available_tasks`) are valid.
When implementing a new dataset (e.g. `AlohaDataset`), policy (e.g. `DiffusionPolicy`), or environment, follow these steps:
- Set the required class attributes: `available_datasets`.
- Set the required class attributes: `name`.
- Update variables in `lerobot/__init__.py` (e.g. `available_envs`, `available_datasets_per_envs`, `available_policies`)
- Update variables in `tests/test_available.py` by importing your new class
"""
import importlib
import pytest
import lerobot
import gymnasium as gym
from lerobot.common.datasets.xarm import XarmDataset
import gymnasium as gym
import pytest
import lerobot
from lerobot.common.datasets.aloha import AlohaDataset
from lerobot.common.datasets.pusht import PushtDataset
from lerobot.common.datasets.xarm import XarmDataset
from lerobot.common.policies.act.modeling_act import ActionChunkingTransformerPolicy
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.common.policies.tdmpc.policy import TDMPCPolicy
from tests.utils import require_env
def test_available():
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
@require_env
def test_available_env_task(env_name: str, task_name: str):
    """
    This test verifies that all environments listed in `lerobot/__init__.py` can
    be successfully imported — if they're installed — and that their
    `available_tasks_per_env` are valid.

    Args:
        env_name: Environment name; the package imported is `gym_{env_name}`.
        task_name: Task identifier; the gym id checked is `gym_{env_name}/{task_name}`.
    """
    package_name = f"gym_{env_name}"
    # Importing the package is what makes its ids show up in the gymnasium registry lookup below.
    importlib.import_module(package_name)
    gym_handle = f"{package_name}/{task_name}"
    assert gym_handle in gym.envs.registry, gym_handle
@pytest.mark.parametrize(
    "env_name, dataset_class",
    [
        ("aloha", AlohaDataset),
        ("pusht", PushtDataset),
        ("xarm", XarmDataset),
    ],
)
def test_available_datasets(env_name, dataset_class):
    """
    Check that each dataset class and `lerobot/__init__.py` agree on which
    datasets exist for an environment.

    The comparison is order-insensitive: `lerobot.available_datasets[env_name]`
    and `dataset_class.available_datasets` are both compared as sets.
    """
    # NOTE: this local keeps its name because the `{...=}` f-string below prints it verbatim.
    available_env_datasets = lerobot.available_datasets[env_name]
    listed_in_package = set(available_env_datasets)
    declared_on_class = set(dataset_class.available_datasets)
    assert listed_in_package == declared_on_class, f"{env_name=} {available_env_datasets=}"
def test_available_policies():
    """
    This test verifies that the class attribute `name` for all policies is
    consistent with those listed in `lerobot/__init__.py`.

    The comparison is order-insensitive: the `name` attributes of the policy
    classes listed below and `lerobot.available_policies` are compared as sets.
    """
    policy_classes = [
        ActionChunkingTransformerPolicy,
        DiffusionPolicy,
        TDMPCPolicy,
    ]
    policies = [pol_cls.name for pol_cls in policy_classes]
    assert set(policies) == set(lerobot.available_policies), policies

View File

@@ -1,33 +1,35 @@
import logging
import os
from pathlib import Path
import einops
import pytest
import torch
from lerobot.common.datasets.utils import compute_stats, get_stats_einops_patterns, load_previous_and_future_frames
from lerobot.common.transforms import Prod
from lerobot.common.utils import init_hydra_config
import logging
from lerobot.common.datasets.factory import make_dataset
from datasets import Dataset
from .utils import DEVICE, DEFAULT_CONFIG_PATH
@pytest.mark.parametrize(
"env_name,dataset_id,policy_name",
[
("xarm", "xarm_lift_medium", "tdmpc"),
("pusht", "pusht", "diffusion"),
("aloha", "aloha_sim_insertion_human", "act"),
("aloha", "aloha_sim_insertion_scripted", "act"),
("aloha", "aloha_sim_transfer_cube_human", "act"),
("aloha", "aloha_sim_transfer_cube_scripted", "act"),
],
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import (
compute_stats,
get_stats_einops_patterns,
load_previous_and_future_frames,
)
from lerobot.common.transforms import Prod
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE
@pytest.mark.parametrize("env_name, dataset_id, policy_name", lerobot.env_dataset_policy_triplets)
def test_factory(env_name, dataset_id, policy_name):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,
overrides=[f"env={env_name}", f"dataset_id={dataset_id}", f"policy={policy_name}", f"device={DEVICE}"]
overrides=[
f"env={env_name}",
f"dataset_id={dataset_id}",
f"policy={policy_name}",
f"device={DEVICE}",
],
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
@@ -51,7 +53,7 @@ def test_factory(env_name, dataset_id, policy_name):
(key, 3, True),
)
assert dataset.hf_dataset[key].dtype == torch.uint8, f"{key}"
# test number of dimensions
for key, ndim, required in keys_ndim_required:
if key not in item:
@@ -60,13 +62,13 @@ def test_factory(env_name, dataset_id, policy_name):
else:
logging.warning(f'Missing key in dataset: "{key}" not in {dataset}.')
continue
if delta_timestamps is not None and key in delta_timestamps:
assert item[key].ndim == ndim + 1, f"{key}"
assert item[key].shape[0] == len(delta_timestamps[key]), f"{key}"
else:
assert item[key].ndim == ndim, f"{key}"
if key in image_keys:
assert item[key].dtype == torch.float32, f"{key}"
# TODO(rcadene): we assume for now that image normalization takes place in the model
@@ -77,17 +79,16 @@ def test_factory(env_name, dataset_id, policy_name):
# test t,c,h,w
assert item[key].shape[1] == 3, f"{key}"
else:
# test c,h,w
# test c,h,w
assert item[key].shape[0] == 3, f"{key}"
if delta_timestamps is not None:
# test missing keys in delta_timestamps
for key in delta_timestamps:
assert key in item, f"{key}"
def test_compute_stats():
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
We compare with taking a straight min, mean, max, std of all the data in one pass (which we can do
@@ -95,20 +96,20 @@ def test_compute_stats():
"""
from lerobot.common.datasets.xarm import XarmDataset
DATA_DIR = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
data_dir = Path(os.environ["DATA_DIR"]) if "DATA_DIR" in os.environ else None
# get transform to convert images from uint8 [0,255] to float32 [0,1]
transform = Prod(in_keys=XarmDataset.image_keys, prod=1 / 255.0)
dataset = XarmDataset(
dataset_id="xarm_lift_medium",
root=DATA_DIR,
root=data_dir,
transform=transform,
)
# Note: we set the batch size to be smaller than the whole dataset to make sure we are testing batched
# computation of the statistics. While doing this, we also make sure it works when we don't divide the
# dataset into even batches.
# dataset into even batches.
computed_stats = compute_stats(dataset, batch_size=int(len(dataset) * 0.25))
# get einops patterns to aggregate batches and compute statistics
@@ -128,7 +129,9 @@ def test_compute_stats():
for k, pattern in stats_patterns.items():
expected_stats[k] = {}
expected_stats[k]["mean"] = einops.reduce(hf_dataset[k], pattern, "mean")
expected_stats[k]["std"] = torch.sqrt(einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean"))
expected_stats[k]["std"] = torch.sqrt(
einops.reduce((hf_dataset[k] - expected_stats[k]["mean"]) ** 2, pattern, "mean")
)
expected_stats[k]["min"] = einops.reduce(hf_dataset[k], pattern, "min")
expected_stats[k]["max"] = einops.reduce(hf_dataset[k], pattern, "max")
@@ -153,12 +156,14 @@ def test_compute_stats():
def test_load_previous_and_future_frames_within_tolerance():
hf_dataset = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
})
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
delta_timestamps = {"index": [-0.2, 0, 0.139]}
@@ -168,13 +173,16 @@ def test_load_previous_and_future_frames_within_tolerance():
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
hf_dataset = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
})
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
delta_timestamps = {"index": [-0.2, 0, 0.141]}
@@ -182,13 +190,16 @@ def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range(
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
hf_dataset = Dataset.from_dict({
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
})
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_data_index_from": [0, 0, 0, 0, 0],
"episode_data_index_to": [5, 5, 5, 5, 5],
}
)
hf_dataset = hf_dataset.with_format("torch")
item = hf_dataset[2]
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
@@ -196,6 +207,6 @@ def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range
item = load_previous_and_future_frames(item, hf_dataset, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(is_pad, torch.tensor([True, False, False, True, True])), "Padding does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"

View File

@@ -1,49 +1,37 @@
import importlib
import gymnasium as gym
import pytest
import torch
from lerobot.common.datasets.factory import make_dataset
import gymnasium as gym
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.envs.factory import make_env
from lerobot.common.utils import init_hydra_config
from lerobot.common.envs.utils import preprocess_observation
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
@pytest.mark.parametrize(
"env_name, task, obs_type",
[
# ("AlohaInsertion-v0", "state"),
("aloha", "AlohaInsertion-v0", "pixels"),
("aloha", "AlohaInsertion-v0", "pixels_agent_pos"),
("aloha", "AlohaTransferCube-v0", "pixels"),
("aloha", "AlohaTransferCube-v0", "pixels_agent_pos"),
("xarm", "XarmLift-v0", "state"),
("xarm", "XarmLift-v0", "pixels"),
("xarm", "XarmLift-v0", "pixels_agent_pos"),
("pusht", "PushT-v0", "state"),
("pusht", "PushT-v0", "pixels"),
("pusht", "PushT-v0", "pixels_agent_pos"),
],
)
def test_env(env_name, task, obs_type):
@pytest.mark.parametrize("obs_type", OBS_TYPES)
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
@require_env
def test_env(env_name, env_task, obs_type):
    """
    Build the gym environment `gym_{env_name}/{env_task}` with the given
    observation type and run gymnasium's `check_env` on the unwrapped
    environment (render check skipped).

    The `aloha` / `state` combination is skipped explicitly.
    """
    if env_name == "aloha" and obs_type == "state":
        pytest.skip("`state` observations not available for aloha")

    package_name = f"gym_{env_name}"
    importlib.import_module(package_name)
    env = gym.make(f"{package_name}/{env_task}", obs_type=obs_type)
    try:
        check_env(env.unwrapped, skip_render_check=True)
    finally:
        # Close the environment even when the check fails, so a failing case does not leak it.
        env.close()
@pytest.mark.parametrize(
"env_name",
[
"pusht",
"xarm",
"aloha",
],
)
@pytest.mark.parametrize("env_name", lerobot.available_envs)
@require_env
def test_factory(env_name):
cfg = init_hydra_config(
DEFAULT_CONFIG_PATH,

View File

@@ -1,5 +1,5 @@
from pathlib import Path
import subprocess
from pathlib import Path
def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> str:
@@ -10,7 +10,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
def _run_script(path):
    """Run the Python script at `path` in a subprocess.

    Args:
        path: Path of the script to execute.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status
            (`check=True`).
    """
    import sys

    # Use the interpreter that is running the tests instead of whatever `python` resolves to on
    # PATH, so the script runs against the same environment and installed packages.
    subprocess.run([sys.executable, path], check=True)
def test_example_1():
@@ -33,7 +33,7 @@ def test_examples_4_and_3():
path = "examples/4_train_policy.py"
with open(path, "r") as file:
with open(path) as file:
file_contents = file.read()
# Do less steps, use smaller batch, use CPU, and don't complicate things with dataloader workers.
@@ -55,7 +55,7 @@ def test_examples_4_and_3():
path = "examples/3_evaluate_pretrained_policy.py"
with open(path, "r") as file:
with open(path) as file:
file_contents = file.read()
# Do less evals, use CPU, and use the local model.
@@ -74,4 +74,4 @@ def test_examples_4_and_3():
],
)
assert Path(f"outputs/train/example_pusht_diffusion").exists()
assert Path("outputs/train/example_pusht_diffusion").exists()

View File

@@ -1,16 +1,18 @@
import pytest
import torch
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.utils import cycle
from lerobot.common.envs.factory import make_env
from lerobot.common.envs.utils import postprocess_action, preprocess_observation
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.policy_protocol import Policy
from lerobot.common.envs.factory import make_env
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.utils import init_hydra_config
from .utils import DEVICE, DEFAULT_CONFIG_PATH
from lerobot.common.utils.utils import init_hydra_config
from .utils import DEFAULT_CONFIG_PATH, DEVICE, require_env
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@@ -21,10 +23,9 @@ from .utils import DEVICE, DEFAULT_CONFIG_PATH
("aloha", "act", ["env.task=AlohaInsertion-v0", "dataset_id=aloha_sim_insertion_scripted"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_human"]),
("aloha", "act", ["env.task=AlohaTransferCube-v0", "dataset_id=aloha_sim_transfer_cube_scripted"]),
# TODO(aliberts): xarm not working with diffusion
# ("xarm", "diffusion", []),
],
)
@require_env
def test_policy(env_name, policy_name, extra_overrides):
"""
Tests:

View File

@@ -1,6 +1,37 @@
import os
import pytest
import torch
from lerobot.common.utils.import_utils import is_package_available
# Pass this as the first argument to init_hydra_config.
DEFAULT_CONFIG_PATH = "lerobot/configs/default.yaml"
DEVICE = os.environ.get('LEROBOT_TESTS_DEVICE', "cuda")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
def require_env(func):
    """
    Decorator that skips the test if the required environment package is not installed.

    The package checked is `gym_{env_name}`, where `env_name` is read from the arguments
    the decorated function is called with. The decorated function must therefore declare
    an `env_name` parameter.

    Args:
        func: The test function to wrap.

    Returns:
        A wrapper with the metadata of `func` (via `functools.wraps`) that calls
        `pytest.skip` when the package is unavailable and otherwise calls `func`
        with the unchanged arguments.

    Raises:
        ValueError: When the wrapper is called and `func` has no `env_name` parameter.
        TypeError: When the call arguments do not match the signature of `func`
            (e.g. `env_name` is missing and has no default).
    """
    import inspect
    from functools import wraps

    # The signature is the same for every call, so resolve it once at decoration time.
    signature = inspect.signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        if "env_name" not in signature.parameters:
            raise ValueError("Function does not have 'env_name' as an argument.")

        # Binding resolves `env_name` whether it is passed positionally, by keyword or left to
        # its default, and fails loudly when it is missing instead of yielding `None`
        # (which would silently skip the test as "gym-None not installed").
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        env_name = bound.arguments["env_name"]

        # Perform the package check
        package_name = f"gym_{env_name}"
        if not is_package_available(package_name):
            pytest.skip(f"gym-{env_name} not installed")

        return func(*args, **kwargs)

    return wrapper