Merge remote-tracking branch 'origin/main' into user/rcadene/2024_06_01_custom_visualize_dataset

This commit is contained in:
Remi Cadene
2024-06-13 15:57:01 +00:00
34 changed files with 1750 additions and 496 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:36f50697dacc82d52d1799dbc53c6c2fb722b9c0bd5bfa90a92dfa336591c74a
size 3686488

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d0e3b4bde97c34606536b655c1e6a23316c9157bd21dcbc73a97500fb985607f
size 40551392

View File

@@ -0,0 +1,86 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import torch
from safetensors.torch import save_file
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.test_image_transforms import ARTIFACT_DIR, DATASET_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH
def save_default_config_transform(original_frame: torch.Tensor, output_dir: Path):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
with seeded_context(1337):
img_tf = default_tf(original_frame)
save_file({"default": img_tf}, output_dir / "default_transforms.safetensors")
def save_single_transforms(original_frame: torch.Tensor, output_dir: Path):
transforms = {
"brightness": [(0.5, 0.5), (2.0, 2.0)],
"contrast": [(0.5, 0.5), (2.0, 2.0)],
"saturation": [(0.5, 0.5), (2.0, 2.0)],
"hue": [(-0.25, -0.25), (0.25, 0.25)],
"sharpness": [(0.5, 0.5), (2.0, 2.0)],
}
frames = {"original_frame": original_frame}
for transform, values in transforms.items():
for min_max in values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
frames[key] = tf(original_frame)
save_file(frames, output_dir / "single_transforms.safetensors")
def main():
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.camera_keys[0]]
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,260 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, require_x86_64_kernel
ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
@pytest.fixture
def img():
dataset = LeRobotDataset(DATASET_REPO_ID)
return dataset[0][dataset.camera_keys[0]]
@pytest.fixture
def img_random():
return torch.rand(3, 480, 640)
@pytest.fixture
def color_jitters():
return [
v2.ColorJitter(brightness=0.5),
v2.ColorJitter(contrast=0.5),
v2.ColorJitter(saturation=0.5),
]
@pytest.fixture
def single_transforms():
return load_file(ARTIFACT_DIR / "single_transforms.safetensors")
@pytest.fixture
def default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
def test_get_image_transforms_no_transform(img):
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
torch.testing.assert_close(tf_actual(img), img)
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
tf_expected = v2.ColorJitter(brightness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
tf_expected = v2.ColorJitter(contrast=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
tf_expected = v2.ColorJitter(saturation=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
tf_expected = v2.ColorJitter(hue=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
tf_expected = SharpnessJitter(sharpness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
def test_get_image_transforms_max_num_transforms(img):
tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=False,
)
tf_expected = v2.Compose(
[
v2.ColorJitter(brightness=(0.5, 0.5)),
v2.ColorJitter(contrast=(0.5, 0.5)),
v2.ColorJitter(saturation=(0.5, 0.5)),
v2.ColorJitter(hue=(0.5, 0.5)),
SharpnessJitter(sharpness=(0.5, 0.5)),
]
)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
out_imgs = []
tf = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
saturation_min_max=(0.5, 0.5),
hue_min_max=(0.5, 0.5),
sharpness_min_max=(0.5, 0.5),
random_order=True,
)
with seeded_context(1337):
for _ in range(10):
out_imgs.append(tf(img))
for i in range(1, len(out_imgs)):
with pytest.raises(AssertionError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
@pytest.mark.parametrize(
"transform, min_max_values",
[
("brightness", [(0.5, 0.5), (2.0, 2.0)]),
("contrast", [(0.5, 0.5), (2.0, 2.0)]),
("saturation", [(0.5, 0.5), (2.0, 2.0)]),
("hue", [(-0.25, -0.25), (0.25, 0.25)]),
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
for min_max in min_max_values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
actual = tf(img)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
expected = single_transforms[key]
torch.testing.assert_close(actual, expected)
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
brightness_weight=cfg_tf.brightness.weight,
brightness_min_max=cfg_tf.brightness.min_max,
contrast_weight=cfg_tf.contrast.weight,
contrast_min_max=cfg_tf.contrast.min_max,
saturation_weight=cfg_tf.saturation.weight,
saturation_min_max=cfg_tf.saturation.min_max,
hue_weight=cfg_tf.hue.weight,
hue_min_max=cfg_tf.hue.min_max,
sharpness_weight=cfg_tf.sharpness.weight,
sharpness_min_max=cfg_tf.sharpness.min_max,
max_num_transforms=cfg_tf.max_num_transforms,
random_order=cfg_tf.random_order,
)
with seeded_context(1337):
actual = default_tf(img)
expected = default_transforms["default"]
torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
actual = random_choice(img)
p_horz, _ = p
if p_horz:
torch.testing.assert_close(actual, F.horizontal_flip(img))
else:
torch.testing.assert_close(actual, F.vertical_flip(img))
def test_random_subset_apply_random_order(img):
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# We can't really check whether the transforms are actually applied in random order. However,
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# applies them in random order, we can use a fixed order to compute the expected value.
actual = random_order(img)
expected = v2.Compose(flips)(img)
torch.testing.assert_close(actual, expected)
def test_random_subset_apply_valid_transforms(color_jitters, img):
transform = RandomSubsetApply(color_jitters)
output = transform(img)
assert output.shape == img.shape
def test_random_subset_apply_probability_length_mismatch(color_jitters):
with pytest.raises(ValueError):
RandomSubsetApply(color_jitters, p=[0.5, 0.5])
@pytest.mark.parametrize("n_subset", [0, 5])
def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
with pytest.raises(ValueError):
RandomSubsetApply(color_jitters, n_subset=n_subset)
def test_sharpness_jitter_valid_range_tuple(img):
tf = SharpnessJitter((0.1, 2.0))
output = tf(img)
assert output.shape == img.shape
def test_sharpness_jitter_valid_range_float(img):
tf = SharpnessJitter(0.5)
output = tf(img)
assert output.shape == img.shape
def test_sharpness_jitter_invalid_range_min_negative():
with pytest.raises(ValueError):
SharpnessJitter((-0.1, 2.0))
def test_sharpness_jitter_invalid_range_max_smaller():
with pytest.raises(ValueError):
SharpnessJitter((2.0, 0.1))

View File

@@ -0,0 +1,352 @@
"""
This file contains generic tests to ensure that nothing breaks if we modify the push_dataset_to_hub API.
Also, this file contains backward compatibility tests. Because they are slow and require to download the raw datasets,
we skip them for now in our CI.
Example to run backward compatiblity tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
from pathlib import Path
import numpy as np
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.push_dataset_to_hub.utils import save_images_concurrently
from lerobot.common.datasets.video_utils import encode_video_frames
from lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub
from tests.utils import require_package_arg
def _mock_download_raw_pusht(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "pusht_cchi_v7_replay.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/action", shape=(num_frames, 1), chunks=(num_frames, 1), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/img",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/n_contacts", shape=(num_frames, 2), chunks=(num_frames, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/state", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/keypoint", shape=(num_frames, 9, 2), chunks=(num_frames, 9, 2), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/action"][:] = np.random.randn(num_frames, 1)
zarr_data["data/img"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/n_contacts"][:] = np.random.randn(num_frames, 2)
zarr_data["data/state"][:] = np.random.randn(num_frames, 5)
zarr_data["data/keypoint"][:] = np.random.randn(num_frames, 9, 2)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_umi(raw_dir, num_frames=4, num_episodes=3):
import zarr
raw_dir.mkdir(parents=True, exist_ok=True)
zarr_path = raw_dir / "cup_in_the_wild.zarr"
store = zarr.DirectoryStore(zarr_path)
zarr_data = zarr.group(store=store)
zarr_data.create_dataset(
"data/camera0_rgb",
shape=(num_frames, 96, 96, 3),
chunks=(num_frames, 96, 96, 3),
dtype=np.uint8,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_end_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_demo_start_pose",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_eef_pos", shape=(num_frames, 5), chunks=(num_frames, 5), dtype=np.float32, overwrite=True
)
zarr_data.create_dataset(
"data/robot0_eef_rot_axis_angle",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"data/robot0_gripper_width",
shape=(num_frames, 5),
chunks=(num_frames, 5),
dtype=np.float32,
overwrite=True,
)
zarr_data.create_dataset(
"meta/episode_ends", shape=(num_episodes,), chunks=(num_episodes,), dtype=np.int32, overwrite=True
)
zarr_data["data/camera0_rgb"][:] = np.random.randint(0, 255, size=(num_frames, 96, 96, 3), dtype=np.uint8)
zarr_data["data/robot0_demo_end_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_demo_start_pose"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_pos"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_eef_rot_axis_angle"][:] = np.random.randn(num_frames, 5)
zarr_data["data/robot0_gripper_width"][:] = np.random.randn(num_frames, 5)
zarr_data["meta/episode_ends"][:] = np.array([1, 3, 4])
store.close()
def _mock_download_raw_xarm(raw_dir, num_frames=4):
import pickle
dataset_dict = {
"observations": {
"rgb": np.random.randint(0, 255, size=(num_frames, 3, 84, 84), dtype=np.uint8),
"state": np.random.randn(num_frames, 4),
},
"actions": np.random.randn(num_frames, 3),
"rewards": np.random.randn(num_frames),
"masks": np.random.randn(num_frames),
"dones": np.array([False, True, True, True]),
}
raw_dir.mkdir(parents=True, exist_ok=True)
pkl_path = raw_dir / "buffer.pkl"
with open(pkl_path, "wb") as f:
pickle.dump(dataset_dict, f)
def _mock_download_raw_aloha(raw_dir, num_frames=6, num_episodes=3):
import h5py
for ep_idx in range(num_episodes):
raw_dir.mkdir(parents=True, exist_ok=True)
path_h5 = raw_dir / f"episode_{ep_idx}.hdf5"
with h5py.File(str(path_h5), "w") as f:
f.create_dataset("action", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qpos", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset("observations/qvel", data=np.random.randn(num_frames // num_episodes, 14))
f.create_dataset(
"observations/images/top",
data=np.random.randint(
0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8
),
)
def _mock_download_raw_dora(raw_dir, num_frames=6, num_episodes=3, fps=30):
from datetime import datetime, timedelta, timezone
import pandas
def write_parquet(key, timestamps, values):
data = {
"timestamp_utc": timestamps,
key: values,
}
df = pandas.DataFrame(data)
raw_dir.mkdir(parents=True, exist_ok=True)
df.to_parquet(raw_dir / f"{key}.parquet", engine="pyarrow")
episode_indices = [None, None, -1, None, None, -1, None, None, -1]
episode_indices_mapping = [0, 0, 0, 1, 1, 1, 2, 2, 2]
frame_indices = [0, 1, -1, 0, 1, -1, 0, 1, -1]
cam_key = "observation.images.cam_high"
timestamps = []
actions = []
states = []
frames = []
# `+ num_episodes`` for buffer frames associated to episode_index=-1
for i, frame_idx in enumerate(frame_indices):
t_utc = datetime.now(timezone.utc) + timedelta(seconds=i / fps)
action = np.random.randn(21).tolist()
state = np.random.randn(21).tolist()
ep_idx = episode_indices_mapping[i]
frame = [{"path": f"videos/{cam_key}_episode_{ep_idx:06d}.mp4", "timestamp": frame_idx / fps}]
timestamps.append(t_utc)
actions.append(action)
states.append(state)
frames.append(frame)
write_parquet(cam_key, timestamps, frames)
write_parquet("observation.state", timestamps, states)
write_parquet("action", timestamps, actions)
write_parquet("episode_index", timestamps, episode_indices)
# write fake mp4 file for each episode
for ep_idx in range(num_episodes):
imgs_array = np.random.randint(0, 255, size=(num_frames // num_episodes, 480, 640, 3), dtype=np.uint8)
tmp_imgs_dir = raw_dir / "tmp_images"
save_images_concurrently(imgs_array, tmp_imgs_dir)
fname = f"{cam_key}_episode_{ep_idx:06d}.mp4"
video_path = raw_dir / "videos" / fname
encode_video_frames(tmp_imgs_dir, video_path, fps)
def _mock_download_raw(raw_dir, repo_id):
if "wrist_gripper" in repo_id:
_mock_download_raw_dora(raw_dir)
elif "aloha" in repo_id:
_mock_download_raw_aloha(raw_dir)
elif "pusht" in repo_id:
_mock_download_raw_pusht(raw_dir)
elif "xarm" in repo_id:
_mock_download_raw_xarm(raw_dir)
elif "umi" in repo_id:
_mock_download_raw_umi(raw_dir)
else:
raise ValueError(repo_id)
def test_push_dataset_to_hub_invalid_repo_id(tmpdir):
with pytest.raises(ValueError):
push_dataset_to_hub(Path(tmpdir), "raw_format", "invalid_repo_id")
def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
tmpdir = Path(tmpdir)
out_dir = tmpdir / "out"
raw_dir = tmpdir / "raw"
# mkdir to skip download
raw_dir.mkdir(parents=True, exist_ok=True)
with pytest.raises(ValueError):
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format="some_format",
repo_id="user/dataset",
local_dir=out_dir,
force_override=False,
)
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id",
[
(["gym-pusht"], "pusht_zarr", "lerobot/pusht"),
(None, "xarm_pkl", "lerobot/xarm_lift_medium"),
(None, "aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
(["imagecodecs"], "umi_zarr", "lerobot/umi_cup_in_the_wild"),
(None, "dora_parquet", "cadene/wrist_gripper"),
],
)
@require_package_arg
def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_id):
num_episodes = 3
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{repo_id}_raw"
_mock_download_raw(raw_dir, repo_id)
local_dir = tmpdir / repo_id
lerobot_dataset = push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
)
# minimal generic tests on the local directory containing LeRobotDataset
assert (local_dir / "meta_data" / "info.json").exists()
assert (local_dir / "meta_data" / "stats.safetensors").exists()
assert (local_dir / "meta_data" / "episode_data_index.safetensors").exists()
for i in range(num_episodes):
for cam_key in lerobot_dataset.camera_keys:
assert (local_dir / "videos" / f"{cam_key}_episode_{i:06d}.mp4").exists()
assert (local_dir / "train" / "dataset_info.json").exists()
assert (local_dir / "train" / "state.json").exists()
assert len(list((local_dir / "train").glob("*.arrow"))) > 0
# minimal generic tests on the item
item = lerobot_dataset[0]
assert "index" in item
assert "episode_index" in item
assert "timestamp" in item
for cam_key in lerobot_dataset.camera_keys:
assert cam_key in item
@pytest.mark.parametrize(
"raw_format, repo_id",
[
# TODO(rcadene): add raw dataset test artifacts
("pusht_zarr", "lerobot/pusht"),
("xarm_pkl", "lerobot/xarm_lift_medium"),
("aloha_hdf5", "lerobot/aloha_sim_insertion_scripted"),
("umi_zarr", "lerobot/umi_cup_in_the_wild"),
("dora_parquet", "cadene/wrist_gripper"),
],
)
@pytest.mark.skip(
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")
tmpdir = Path(tmpdir)
raw_dir = tmpdir / f"{dataset_id}_raw"
local_dir = tmpdir / repo_id
push_dataset_to_hub(
raw_dir=raw_dir,
raw_format=raw_format,
repo_id=repo_id,
push_to_hub=False,
local_dir=local_dir,
force_override=False,
cache_dir=tmpdir / "cache",
episodes=[0],
)
ds_actual = LeRobotDataset(repo_id, root=tmpdir)
ds_reference = LeRobotDataset(repo_id)
assert len(ds_reference.hf_dataset) == len(ds_actual.hf_dataset)
def check_same_items(item1, item2):
assert item1.keys() == item2.keys(), "Keys mismatch"
for key in item1:
if isinstance(item1[key], torch.Tensor) and isinstance(item2[key], torch.Tensor):
assert torch.equal(item1[key], item2[key]), f"Mismatch found in key: {key}"
else:
assert item1[key] == item2[key], f"Mismatch found in key: {key}"
for i in range(len(ds_reference.hf_dataset)):
item_reference = ds_reference.hf_dataset[i]
item_actual = ds_actual.hf_dataset[i]
check_same_items(item_reference, item_actual)

View File

@@ -13,6 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset
@@ -30,3 +32,20 @@ def test_visualize_dataset(tmpdir, repo_id):
serve=False,
)
assert rrd_path.exists()
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
rrd_path = visualize_dataset(
repo_id,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
root=root,
)
assert rrd_path.exists()

View File

@@ -76,6 +76,7 @@ def require_env(func):
"""
Decorator that skips the test if the required environment package is not installed.
As it need 'env_name' in args, it also checks whether it is provided as an argument.
If 'env_name' is None, this check is skipped.
"""
@wraps(func)
@@ -91,7 +92,7 @@ def require_env(func):
# Perform the package check
package_name = f"gym_{env_name}"
if not is_package_available(package_name):
if env_name is not None and not is_package_available(package_name):
pytest.skip(f"gym-{env_name} not installed")
return func(*args, **kwargs)
@@ -99,6 +100,38 @@ def require_env(func):
return wrapper
def require_package_arg(func):
"""
Decorator that skips the test if the required package is not installed.
This is similar to `require_env` but more general in that it can check any package (not just environments).
As it need 'required_packages' in args, it also checks whether it is provided as an argument.
If 'required_packages' is None, this check is skipped.
"""
@wraps(func)
def wrapper(*args, **kwargs):
# Determine if 'required_packages' is provided and extract its value
arg_names = func.__code__.co_varnames[: func.__code__.co_argcount]
if "required_packages" in arg_names:
# Get the index of 'required_packages' and retrieve the value from args
index = arg_names.index("required_packages")
required_packages = args[index] if len(args) > index else kwargs.get("required_packages")
else:
raise ValueError("Function does not have 'required_packages' as an argument.")
if required_packages is None:
return func(*args, **kwargs)
# Perform the package check
for package in required_packages:
if not is_package_available(package):
pytest.skip(f"{package} not installed")
return func(*args, **kwargs)
return wrapper
def require_package(package_name):
"""
Decorator that skips the test if the specified package is not installed.