Dataset v2.0 (#461)

Author: Simon Alibert
Co-authored-by: Remi <remi.cadene@huggingface.co>
Committed via GitHub on 2024-11-29 19:04:00 +01:00
commit 32eb0cec8f (parent 96c7052777)
71 changed files with 6115 additions and 2235 deletions


@@ -23,6 +23,13 @@ from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus
# Import fixture modules as plugins
pytest_plugins = [
"tests.fixtures.dataset_factories",
"tests.fixtures.files",
"tests.fixtures.hub",
]
def pytest_collection_finish():
print(f"\nTesting with {DEVICE=}")
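For background (illustrative only, not part of this diff): listing modules in pytest_plugins makes every fixture they define available to any test in the suite without an explicit import. Assuming the img_array_factory fixture added below in tests/fixtures/dataset_factories.py, a hypothetical test module could simply request it by name:

# tests/test_example_usage.py (hypothetical module name)
def test_uses_shared_fixture(img_array_factory):
    img = img_array_factory(height=64, width=64)
    # img_array_factory defaults to 3 channels and uint8 values in [0, 255]
    assert img.shape == (64, 64, 3)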

tests/fixtures/constants.py (new file, 29 lines)

@@ -0,0 +1,29 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME
LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {
"action": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
"state": {
"dtype": "float32",
"shape": (6,),
"names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
},
}
DUMMY_CAMERA_FEATURES = {
"laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
"phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {
"video.fps": DEFAULT_FPS,
"video.codec": "av1",
"video.pix_fmt": "yuv420p",
"video.is_depth_map": False,
"has_audio": False,
}

tests/fixtures/dataset_factories.py (new file, 396 lines)

@@ -0,0 +1,396 @@
import random
from pathlib import Path
from unittest.mock import patch
import datasets
import numpy as np
import PIL.Image
import pytest
import torch
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_FEATURES,
DEFAULT_PARQUET_PATH,
DEFAULT_VIDEO_PATH,
get_hf_features_from_features,
hf_transform_to_torch,
)
from tests.fixtures.constants import (
DEFAULT_FPS,
DUMMY_CAMERA_FEATURES,
DUMMY_MOTOR_FEATURES,
DUMMY_REPO_ID,
DUMMY_ROBOT_TYPE,
DUMMY_VIDEO_INFO,
)
def get_task_index(task_dicts: dict, task: str) -> int:
tasks = {d["task_index"]: d["task"] for d in task_dicts}
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
return task_to_task_index[task]
@pytest.fixture(scope="session")
def img_tensor_factory():
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
return torch.rand((channels, height, width), dtype=dtype)
return _create_img_tensor
@pytest.fixture(scope="session")
def img_array_factory():
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
if np.issubdtype(dtype, np.unsignedinteger):
# Int array in [0, 255] range
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
elif np.issubdtype(dtype, np.floating):
# Float array in [0, 1] range
img_array = np.random.rand(height, width, channels).astype(dtype)
else:
raise ValueError(dtype)
return img_array
return _create_img_array
@pytest.fixture(scope="session")
def img_factory(img_array_factory):
def _create_img(height=100, width=100) -> PIL.Image.Image:
img_array = img_array_factory(height=height, width=width)
return PIL.Image.fromarray(img_array)
return _create_img
@pytest.fixture(scope="session")
def features_factory():
def _create_features(
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
if use_videos:
camera_ft = {
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
}
else:
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
return {
**motor_features,
**camera_ft,
**DEFAULT_FEATURES,
}
return _create_features
@pytest.fixture(scope="session")
def info_factory(features_factory):
def _create_info(
codebase_version: str = CODEBASE_VERSION,
fps: int = DEFAULT_FPS,
robot_type: str = DUMMY_ROBOT_TYPE,
total_episodes: int = 0,
total_frames: int = 0,
total_tasks: int = 0,
total_videos: int = 0,
total_chunks: int = 0,
chunks_size: int = DEFAULT_CHUNK_SIZE,
data_path: str = DEFAULT_PARQUET_PATH,
video_path: str = DEFAULT_VIDEO_PATH,
motor_features: dict = DUMMY_MOTOR_FEATURES,
camera_features: dict = DUMMY_CAMERA_FEATURES,
use_videos: bool = True,
) -> dict:
features = features_factory(motor_features, camera_features, use_videos)
return {
"codebase_version": codebase_version,
"robot_type": robot_type,
"total_episodes": total_episodes,
"total_frames": total_frames,
"total_tasks": total_tasks,
"total_videos": total_videos,
"total_chunks": total_chunks,
"chunks_size": chunks_size,
"fps": fps,
"splits": {},
"data_path": data_path,
"video_path": video_path if use_videos else None,
"features": features,
}
return _create_info
@pytest.fixture(scope="session")
def stats_factory():
def _create_stats(
features: dict[str] | None = None,
) -> dict:
stats = {}
for key, ft in features.items():
shape = ft["shape"]
dtype = ft["dtype"]
if dtype in ["image", "video"]:
stats[key] = {
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
}
else:
stats[key] = {
"max": np.full(shape, 1, dtype=dtype).tolist(),
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
"min": np.full(shape, 0, dtype=dtype).tolist(),
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
}
return stats
return _create_stats
@pytest.fixture(scope="session")
def tasks_factory():
def _create_tasks(total_tasks: int = 3) -> list[dict]:
tasks_list = []
for i in range(total_tasks):
task_dict = {"task_index": i, "task": f"Perform action {i}."}
tasks_list.append(task_dict)
return tasks_list
return _create_tasks
@pytest.fixture(scope="session")
def episodes_factory(tasks_factory):
def _create_episodes(
total_episodes: int = 3,
total_frames: int = 400,
tasks: dict | None = None,
multi_task: bool = False,
):
if total_episodes <= 0 or total_frames <= 0:
raise ValueError("num_episodes and total_length must be positive integers.")
if total_frames < total_episodes:
raise ValueError("total_length must be greater than or equal to num_episodes.")
if not tasks:
min_tasks = 2 if multi_task else 1
total_tasks = random.randint(min_tasks, total_episodes)
tasks = tasks_factory(total_tasks)
if total_episodes < len(tasks) and not multi_task:
raise ValueError("The number of tasks should be less than the number of episodes.")
# Generate random lengths that sum up to total_length
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
tasks_list = [task_dict["task"] for task_dict in tasks]
num_tasks_available = len(tasks_list)
episodes_list = []
remaining_tasks = tasks_list.copy()
for ep_idx in range(total_episodes):
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
if remaining_tasks:
for task in episode_tasks:
remaining_tasks.remove(task)
episodes_list.append(
{
"episode_index": ep_idx,
"tasks": episode_tasks,
"length": lengths[ep_idx],
}
)
return episodes_list
return _create_episodes
@pytest.fixture(scope="session")
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
def _create_hf_dataset(
features: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
fps: int = DEFAULT_FPS,
) -> datasets.Dataset:
if not tasks:
tasks = tasks_factory()
if not episodes:
episodes = episodes_factory()
if not features:
features = features_factory()
timestamp_col = np.array([], dtype=np.float32)
frame_index_col = np.array([], dtype=np.int64)
episode_index_col = np.array([], dtype=np.int64)
task_index = np.array([], dtype=np.int64)
for ep_dict in episodes:
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
episode_index_col = np.concatenate(
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
)
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
index_col = np.arange(len(episode_index_col))
robot_cols = {}
for key, ft in features.items():
if ft["dtype"] == "image":
robot_cols[key] = [
img_array_factory(height=ft["shape"][0], width=ft["shape"][1])
for _ in range(len(index_col))
]
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
hf_features = get_hf_features_from_features(features)
dataset = datasets.Dataset.from_dict(
{
**robot_cols,
"timestamp": timestamp_col,
"frame_index": frame_index_col,
"episode_index": episode_index_col,
"index": index_col,
"task_index": task_index,
},
features=hf_features,
)
dataset.set_transform(hf_transform_to_torch)
return dataset
return _create_hf_dataset
@pytest.fixture(scope="session")
def lerobot_dataset_metadata_factory(
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
mock_snapshot_download_factory,
):
def _create_lerobot_dataset_metadata(
root: Path,
repo_id: str = DUMMY_REPO_ID,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
local_files_only: bool = False,
) -> LeRobotDatasetMetadata:
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
tasks=tasks,
episodes=episodes,
)
with (
patch(
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
) as mock_get_hub_safe_version_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
return _create_lerobot_dataset_metadata
@pytest.fixture(scope="session")
def lerobot_dataset_factory(
info_factory,
stats_factory,
tasks_factory,
episodes_factory,
hf_dataset_factory,
mock_snapshot_download_factory,
lerobot_dataset_metadata_factory,
):
def _create_lerobot_dataset(
root: Path,
repo_id: str = DUMMY_REPO_ID,
total_episodes: int = 3,
total_frames: int = 150,
total_tasks: int = 1,
multi_task: bool = False,
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episode_dicts: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
**kwargs,
) -> LeRobotDataset:
if not info:
info = info_factory(
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
)
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episode_dicts:
episode_dicts = episodes_factory(
total_episodes=info["total_episodes"],
total_frames=info["total_frames"],
tasks=tasks,
multi_task=multi_task,
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
mock_snapshot_download = mock_snapshot_download_factory(
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
hf_dataset=hf_dataset,
)
mock_metadata = lerobot_dataset_metadata_factory(
root=root,
repo_id=repo_id,
info=info,
stats=stats,
tasks=tasks,
episodes=episode_dicts,
local_files_only=kwargs.get("local_files_only", False),
)
with (
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
patch(
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
) as mock_snapshot_download_patch,
):
mock_metadata_patch.return_value = mock_metadata
mock_snapshot_download_patch.side_effect = mock_snapshot_download
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
return _create_lerobot_dataset

tests/fixtures/files.py (new file, 114 lines)

@@ -0,0 +1,114 @@
import json
from pathlib import Path
import datasets
import jsonlines
import pyarrow.compute as pc
import pyarrow.parquet as pq
import pytest
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
@pytest.fixture(scope="session")
def info_path(info_factory):
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
if not info:
info = info_factory()
fpath = dir / INFO_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(info, f, indent=4, ensure_ascii=False)
return fpath
return _create_info_json_file
@pytest.fixture(scope="session")
def stats_path(stats_factory):
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
if not stats:
stats = stats_factory()
fpath = dir / STATS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with open(fpath, "w") as f:
json.dump(stats, f, indent=4, ensure_ascii=False)
return fpath
return _create_stats_json_file
@pytest.fixture(scope="session")
def tasks_path(tasks_factory):
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
if not tasks:
tasks = tasks_factory()
fpath = dir / TASKS_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(tasks)
return fpath
return _create_tasks_jsonl_file
@pytest.fixture(scope="session")
def episode_path(episodes_factory):
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
if not episodes:
episodes = episodes_factory()
fpath = dir / EPISODES_PATH
fpath.parent.mkdir(parents=True, exist_ok=True)
with jsonlines.open(fpath, "w") as writer:
writer.write_all(episodes)
return fpath
return _create_episodes_jsonl_file
@pytest.fixture(scope="session")
def single_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_single_episode_parquet(
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return fpath
return _create_single_episode_parquet
@pytest.fixture(scope="session")
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
def _create_multi_episode_parquet(
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
) -> Path:
if not info:
info = info_factory()
if hf_dataset is None:
hf_dataset = hf_dataset_factory()
data_path = info["data_path"]
chunks_size = info["chunks_size"]
total_episodes = info["total_episodes"]
for ep_idx in range(total_episodes):
ep_chunk = ep_idx // chunks_size
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
fpath.parent.mkdir(parents=True, exist_ok=True)
table = hf_dataset.data.table
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
pq.write_table(ep_table, fpath)
return dir / "data"
return _create_multi_episode_parquet

tests/fixtures/hub.py (new file, 105 lines)

@@ -0,0 +1,105 @@
from pathlib import Path
import datasets
import pytest
from huggingface_hub.utils import filter_repo_objects
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
from tests.fixtures.constants import LEROBOT_TEST_DIR
@pytest.fixture(scope="session")
def mock_snapshot_download_factory(
info_factory,
info_path,
stats_factory,
stats_path,
tasks_factory,
tasks_path,
episodes_factory,
episode_path,
single_episode_parquet_path,
hf_dataset_factory,
):
"""
This factory allows patching snapshot_download so that, when called, it creates the expected files locally
instead of making calls to the Hub API. Its design lets you pass explicitly the files you want to be created.
"""
def _mock_snapshot_download_func(
info: dict | None = None,
stats: dict | None = None,
tasks: list[dict] | None = None,
episodes: list[dict] | None = None,
hf_dataset: datasets.Dataset | None = None,
):
if not info:
info = info_factory()
if not stats:
stats = stats_factory(features=info["features"])
if not tasks:
tasks = tasks_factory(total_tasks=info["total_tasks"])
if not episodes:
episodes = episodes_factory(
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
)
if not hf_dataset:
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
def _extract_episode_index_from_path(fpath: str) -> int | None:
path = Path(fpath)
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
return episode_index
else:
return None
def _mock_snapshot_download(
repo_id: str,
local_dir: str | Path | None = None,
allow_patterns: str | list[str] | None = None,
ignore_patterns: str | list[str] | None = None,
*args,
**kwargs,
) -> str:
if not local_dir:
local_dir = LEROBOT_TEST_DIR
# List all possible files
all_files = []
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
all_files.extend(meta_files)
data_files = []
for episode_dict in episodes:
ep_idx = episode_dict["episode_index"]
ep_chunk = ep_idx // info["chunks_size"]
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
data_files.append(data_path)
all_files.extend(data_files)
allowed_files = filter_repo_objects(
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
)
# Create allowed files
for rel_path in allowed_files:
if rel_path.startswith("data/"):
episode_index = _extract_episode_index_from_path(rel_path)
if episode_index is not None:
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
if rel_path == INFO_PATH:
_ = info_path(local_dir, info)
elif rel_path == STATS_PATH:
_ = stats_path(local_dir, stats)
elif rel_path == TASKS_PATH:
_ = tasks_path(local_dir, tasks)
elif rel_path == EPISODES_PATH:
_ = episode_path(local_dir, episodes)
else:
pass
return str(local_dir)
return _mock_snapshot_download
return _mock_snapshot_download_func
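For illustration only (not part of the commit), a minimal sketch of how a test might use this factory; the test name and assertions are hypothetical, while the patch target matches the one used in the dataset factories above:

# hypothetical usage sketch of mock_snapshot_download_factory
from pathlib import Path
from unittest.mock import patch

def test_snapshot_download_is_mocked(mock_snapshot_download_factory, tmp_path):
    # Drop-in replacement for snapshot_download that writes files locally.
    fake_snapshot_download = mock_snapshot_download_factory()
    with patch(
        "lerobot.common.datasets.lerobot_dataset.snapshot_download",
        side_effect=fake_snapshot_download,
    ):
        # Code under test calling snapshot_download() now materializes the
        # metadata json/jsonl files and per-episode parquet files on disk.
        local_dir = fake_snapshot_download("dummy/repo", local_dir=tmp_path)
    assert any(Path(local_dir).rglob("*.parquet"))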


@@ -76,7 +76,7 @@ def main():
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
output_dir = Path(ARTIFACT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
original_frame = dataset[0][dataset.camera_keys[0]]
original_frame = dataset[0][dataset.meta.camera_keys[0]]
save_single_transforms(original_frame, output_dir)
save_default_config_transform(original_frame, output_dir)


@@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
)
set_global_seed(1337)
dataset = make_dataset(cfg)
policy = make_policy(cfg, dataset_stats=dataset.stats)
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
policy.train()
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)


@@ -29,7 +29,6 @@ from unittest.mock import patch
import pytest
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
from lerobot.common.logger import Logger
from lerobot.common.policies.factory import make_policy
from lerobot.common.utils.utils import init_hydra_config
@@ -93,8 +92,9 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
mock_calibration_dir(calibration_dir)
overrides.append(f"calibration_dir={calibration_dir}")
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock)
record(
@@ -102,6 +102,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
fps=30,
root=root,
repo_id=repo_id,
single_task=single_task,
warmup_time_s=1,
episode_time_s=1,
num_episodes=2,
@@ -132,17 +133,18 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
env_name = "koch_real"
policy_name = "act_koch_real"
root = tmpdir / "data"
repo_id = "lerobot/debug"
eval_repo_id = "lerobot/eval_debug"
root = tmpdir / "data" / repo_id
single_task = "Do something."
robot = make_robot(robot_type, overrides=overrides, mock=mock)
dataset = record(
robot,
root,
repo_id,
single_task,
fps=1,
warmup_time_s=1,
warmup_time_s=0.5,
episode_time_s=1,
reset_time_s=1,
num_episodes=2,
@@ -153,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
display_cameras=False,
play_sounds=False,
)
assert dataset.num_episodes == 2
assert dataset.meta.total_episodes == 2
assert len(dataset) == 2
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
@@ -191,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
overrides=overrides,
)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
out_dir = tmpdir / "logger"
logger = Logger(cfg, out_dir, wandb_job_name="debug")
@@ -225,10 +227,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
else:
num_image_writer_processes = 0
record(
eval_repo_id = "lerobot/eval_debug"
eval_root = tmpdir / "data" / eval_repo_id
dataset = record(
robot,
root,
eval_root,
eval_repo_id,
single_task,
pretrained_policy_name_or_path,
warmup_time_s=1,
episode_time_s=1,
@@ -265,51 +271,36 @@ def test_resume_record(tmpdir, request, robot_type, mock):
robot = make_robot(robot_type, overrides=overrides, mock=mock)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=1,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
record_kwargs = {
"robot": robot,
"root": root,
"repo_id": repo_id,
"single_task": single_task,
"fps": 1,
"warmup_time_s": 0,
"episode_time_s": 1,
"push_to_hub": False,
"video": False,
"display_cameras": False,
"play_sounds": False,
"run_compute_stats": False,
"local_files_only": True,
"num_episodes": 1,
}
init_dataset_return_value = {}
dataset = record(**record_kwargs)
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
def wrapped_init_dataset(*args, **kwargs):
nonlocal init_dataset_return_value
init_dataset_return_value = init_dataset(*args, **kwargs)
return init_dataset_return_value
with pytest.raises(FileExistsError):
# Dataset already exists, but resume=False by default
record(**record_kwargs)
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
dataset = record(
robot,
root,
repo_id,
fps=1,
warmup_time_s=0,
episode_time_s=1,
num_episodes=2,
push_to_hub=False,
video=False,
display_cameras=False,
play_sounds=False,
run_compute_stats=False,
)
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
assert (
init_dataset_return_value["num_episodes"] == 2
), "`init_dataset` should load the previous episode"
dataset = record(**record_kwargs, resume=True)
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
@@ -328,23 +319,22 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = True
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
single_task,
fps=1,
warmup_time_s=0,
episode_time_s=1,
@@ -358,7 +348,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@@ -378,23 +367,22 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = False
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
fps=2,
root=root,
single_task=single_task,
repo_id=repo_id,
warmup_time_s=0,
episode_time_s=1,
@@ -407,7 +395,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
@@ -429,23 +416,22 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
overrides = []
robot = make_robot(robot_type, overrides=overrides, mock=mock)
with (
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
):
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
mock_events = {}
mock_events["exit_early"] = True
mock_events["rerecord_episode"] = False
mock_events["stop_recording"] = True
mock_listener.return_value = (None, mock_events)
root = Path(tmpdir) / "data"
repo_id = "lerobot/debug"
root = Path(tmpdir) / "data" / repo_id
single_task = "Do something."
dataset = record(
robot,
root,
repo_id,
single_task=single_task,
fps=1,
warmup_time_s=0,
episode_time_s=1,
@@ -459,5 +445,4 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
)
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
assert len(dataset) == 1, "`dataset` should contain only 1 frame"


@@ -33,18 +33,72 @@ from lerobot.common.datasets.compute_stats import (
get_stats_einops_patterns,
)
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
from lerobot.common.datasets.lerobot_dataset import (
LeRobotDataset,
MultiLeRobotDataset,
)
from lerobot.common.datasets.utils import (
create_branch,
flatten_dict,
hf_transform_to_torch,
load_previous_and_future_frames,
unflatten_dict,
)
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
"""
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
objects have the same sets of attributes defined.
"""
# Instantiate both ways
robot = make_robot("koch", mock=True)
root_create = tmp_path / "create"
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
root_init = tmp_path / "init"
dataset_init = lerobot_dataset_factory(root=root_init)
# Access the '_hub_version' cached_property in both instances to force its creation
_ = dataset_init.meta._hub_version
_ = dataset_create.meta._hub_version
init_attr = set(vars(dataset_init).keys())
create_attr = set(vars(dataset_create).keys())
assert init_attr == create_attr
def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
kwargs = {
"repo_id": DUMMY_REPO_ID,
"total_episodes": 10,
"total_frames": 400,
"episodes": [2, 5, 6],
}
dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
assert dataset.repo_id == kwargs["repo_id"]
assert dataset.meta.total_episodes == kwargs["total_episodes"]
assert dataset.meta.total_frames == kwargs["total_frames"]
assert dataset.episodes == kwargs["episodes"]
assert dataset.num_episodes == len(kwargs["episodes"])
assert dataset.num_frames == len(dataset)
# TODO(aliberts):
# - [ ] test various attributes & state from init and create
# - [ ] test init with episodes and check num_frames
# - [ ] test add_frame
# - [ ] test add_episode
# - [ ] test consolidate
# - [ ] test push_to_hub
# - [ ] test smaller methods
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
lerobot.env_dataset_policy_triplets
@@ -67,7 +121,7 @@ def test_factory(env_name, repo_id, policy_name):
)
dataset = make_dataset(cfg)
delta_timestamps = dataset.delta_timestamps
camera_keys = dataset.camera_keys
camera_keys = dataset.meta.camera_keys
item = dataset[0]
@@ -117,6 +171,7 @@ def test_factory(env_name, repo_id, policy_name):
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_multilerobotdataset_frames():
"""Check that all dataset frames are incorporated."""
# Note: use the image variants of the dataset to make the test approx 3x faster.
@@ -130,7 +185,7 @@ def test_multilerobotdataset_frames():
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
dataset = MultiLeRobotDataset(repo_ids)
assert len(dataset) == sum(len(d) for d in sub_datasets)
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
@@ -149,6 +204,8 @@ def test_multilerobotdataset_frames():
assert torch.equal(sub_dataset_item[k], dataset_item[k])
# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_compute_stats_on_xarm():
"""Check that the statistics are computed correctly according to the stats_patterns property.
@@ -197,7 +254,7 @@ def test_compute_stats_on_xarm():
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
# load stats used during training which are expected to match the ones returned by computed_stats
loaded_stats = dataset.stats # noqa: F841
loaded_stats = dataset.meta.stats # noqa: F841
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
# # test loaded stats match expected stats
@@ -208,72 +265,7 @@ def test_compute_stats_on_xarm():
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
def test_load_previous_and_future_frames_within_tolerance():
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.139]}
tol = 0.04
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
assert not is_pad.any(), "Unexpected padding detected"
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.2, 0, 0.141]}
tol = 0.04
item = hf_dataset[2]
with pytest.raises(AssertionError):
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
hf_dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
"index": [0, 1, 2, 3, 4],
"episode_index": [0, 0, 0, 0, 0],
}
)
hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([5]),
}
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
tol = 0.04
item = hf_dataset[2]
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
data, is_pad = item["index"], item["index_is_pad"]
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
assert torch.equal(
is_pad, torch.tensor([True, False, False, True, True])
), "Padding does not match expected values"
# TODO(aliberts): Move to more appropriate location
def test_flatten_unflatten_dict():
d = {
"obs": {
@@ -297,6 +289,7 @@ def test_flatten_unflatten_dict():
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"repo_id",
[
@@ -368,6 +361,7 @@ def test_backward_compatibility(repo_id):
# load_and_compare(i - 1)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_aggregate_stats():
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
with seeded_context(0):


@@ -0,0 +1,256 @@
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
check_delta_timestamps,
check_timestamps_sync,
get_delta_indices,
hf_transform_to_torch,
)
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
@pytest.fixture(scope="module")
def synced_hf_dataset_factory(hf_dataset_factory):
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
return hf_dataset_factory(fps=fps)
return _create_synced_hf_dataset
@pytest.fixture(scope="module")
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just outside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
return _create_unsynced_hf_dataset
@pytest.fixture(scope="module")
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
hf_dataset = synced_hf_dataset_factory(fps=fps)
features = hf_dataset.features
df = hf_dataset.to_pandas()
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
# Modify a single timestamp just inside tolerance
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
return unsynced_hf_dataset
return _create_slightly_off_hf_dataset
@pytest.fixture(scope="module")
def valid_delta_timestamps_factory():
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
return delta_timestamps
return _create_valid_delta_timestamps
@pytest.fixture(scope="module")
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_invalid_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
) -> dict:
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just outside tolerance
for key in keys:
delta_timestamps[key][3] += tolerance_s * 1.1
return delta_timestamps
return _create_invalid_delta_timestamps
@pytest.fixture(scope="module")
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
def _create_slightly_off_delta_timestamps(
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
) -> dict:
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
# Modify a single timestamp just inside tolerance
for key in delta_timestamps:
delta_timestamps[key][3] += tolerance_s * 0.9
delta_timestamps[key][-3] += tolerance_s * 0.9
return delta_timestamps
return _create_slightly_off_delta_timestamps
@pytest.fixture(scope="module")
def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
return {key: list(range(-10, 10)) for key in keys}
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
synced_hf_dataset = synced_hf_dataset_factory(fps)
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
result = check_timestamps_sync(
hf_dataset=synced_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
with pytest.raises(ValueError):
check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
result = check_timestamps_sync(
hf_dataset=unsynced_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
)
assert result is False
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
result = check_timestamps_sync(
hf_dataset=slightly_off_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_timestamps_sync_single_timestamp():
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
fps = 30
tolerance_s = 1e-4
result = check_timestamps_sync(
hf_dataset=single_timestamp_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset
@pytest.mark.skip("TODO: fix")
def test_check_timestamps_sync_empty_dataset():
fps = 30
tolerance_s = 1e-4
empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []})
empty_hf_dataset.set_transform(hf_transform_to_torch)
episode_data_index = {
"to": torch.tensor([], dtype=torch.int64),
"from": torch.tensor([], dtype=torch.int64),
}
result = check_timestamps_sync(
hf_dataset=empty_hf_dataset,
episode_data_index=episode_data_index,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
valid_delta_timestamps = valid_delta_timestamps_factory(fps)
result = check_delta_timestamps(
delta_timestamps=valid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
result = check_delta_timestamps(
delta_timestamps=slightly_off_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
with pytest.raises(ValueError):
check_delta_timestamps(
delta_timestamps=invalid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
fps = 30
tolerance_s = 1e-4
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
result = check_delta_timestamps(
delta_timestamps=invalid_delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
raise_value_error=False,
)
assert result is False
def test_check_delta_timestamps_empty():
delta_timestamps = {}
fps = 30
tolerance_s = 1e-4
result = check_delta_timestamps(
delta_timestamps=delta_timestamps,
fps=fps,
tolerance_s=tolerance_s,
)
assert result is True
def test_delta_indices(valid_delta_timestamps_factory, delta_indices):
fps = 30
delta_timestamps = valid_delta_timestamps_factory(fps)
expected_delta_indices = delta_indices
actual_delta_indices = get_delta_indices(delta_timestamps, fps)
assert expected_delta_indices == actual_delta_indices


@@ -13,12 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(aliberts): Mute logging for these tests
import io
import subprocess
import sys
from pathlib import Path
import pytest
from tests.fixtures.constants import DUMMY_REPO_ID
from tests.utils import require_package
@@ -29,6 +32,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
return text
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
def _run_script(path):
subprocess.run([sys.executable, path], check=True)
@@ -38,12 +42,26 @@ def _read_file(path):
return file.read()
def test_example_1():
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
def test_example_1(tmp_path, lerobot_dataset_factory):
_ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID)
path = "examples/1_load_lerobot_dataset.py"
_run_script(path)
file_contents = _read_file(path)
file_contents = _find_and_replace(
file_contents,
[
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
(
"LeRobotDataset(repo_id",
f"LeRobotDataset(repo_id, root='{str(tmp_path)}', local_files_only=True",
),
],
)
exec(file_contents, {})
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
@require_package("gym_pusht")
def test_examples_basic2_basic3_advanced1():
"""
@@ -111,7 +129,8 @@ def test_examples_basic2_basic3_advanced1():
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
),
('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"),
("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"),
("num_workers=4", "num_workers=0"),
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
("batch_size=64", "batch_size=1"),


@@ -15,15 +15,12 @@
# limitations under the License.
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F # noqa: N812
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
@@ -33,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"
def load_png_to_tensor(path: Path):
return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)
@pytest.fixture
def img():
dataset = LeRobotDataset(DATASET_REPO_ID)
return dataset[0][dataset.camera_keys[0]]
@pytest.fixture
def img_random():
return torch.rand(3, 480, 640)
@pytest.fixture
def color_jitters():
return [
@@ -67,47 +49,54 @@ def default_transforms():
return load_file(ARTIFACT_DIR / "default_transforms.safetensors")
def test_get_image_transforms_no_transform(img):
def test_get_image_transforms_no_transform(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
torch.testing.assert_close(tf_actual(img), img)
torch.testing.assert_close(tf_actual(img_tensor), img_tensor)
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
tf_expected = v2.ColorJitter(brightness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
tf_expected = v2.ColorJitter(contrast=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
tf_expected = v2.ColorJitter(saturation=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
tf_expected = v2.ColorJitter(hue=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
tf_expected = SharpnessJitter(sharpness=min_max)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
def test_get_image_transforms_max_num_transforms(img):
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
img_tensor = img_tensor_factory()
tf_actual = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
@@ -125,12 +114,13 @@ def test_get_image_transforms_max_num_transforms(img):
SharpnessJitter(sharpness=(0.5, 0.5)),
]
)
torch.testing.assert_close(tf_actual(img), tf_expected(img))
torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))
@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
def test_get_image_transforms_random_order(img_tensor_factory):
out_imgs = []
img_tensor = img_tensor_factory()
tf = get_image_transforms(
brightness_min_max=(0.5, 0.5),
contrast_min_max=(0.5, 0.5),
@@ -141,13 +131,14 @@ def test_get_image_transforms_random_order(img):
)
with seeded_context(1337):
for _ in range(10):
out_imgs.append(tf(img))
out_imgs.append(tf(img_tensor))
for i in range(1, len(out_imgs)):
with pytest.raises(AssertionError):
torch.testing.assert_close(out_imgs[0], out_imgs[i])
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"transform, min_max_values",
[
@@ -158,21 +149,24 @@ def test_get_image_transforms_random_order(img):
("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
img_tensor = img_tensor_factory()
for min_max in min_max_values:
kwargs = {
f"{transform}_weight": 1.0,
f"{transform}_min_max": min_max,
}
tf = get_image_transforms(**kwargs)
actual = tf(img)
actual = tf(img_tensor)
key = f"{transform}_{min_max[0]}_{min_max[1]}"
expected = single_transforms[key]
torch.testing.assert_close(actual, expected)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
img_tensor = img_tensor_factory()
cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
cfg_tf = cfg.training.image_transforms
default_tf = get_image_transforms(
@@ -191,7 +185,7 @@ def test_backward_compatibility_default_config(img, default_transforms):
)
with seeded_context(1337):
actual = default_tf(img)
actual = default_tf(img_tensor)
expected = default_transforms["default"]
@@ -199,33 +193,36 @@ def test_backward_compatibility_default_config(img, default_transforms):
@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
def test_random_subset_apply_single_choice(img_tensor_factory, p):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
actual = random_choice(img)
actual = random_choice(img_tensor)
p_horz, _ = p
if p_horz:
torch.testing.assert_close(actual, F.horizontal_flip(img))
torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
else:
torch.testing.assert_close(actual, F.vertical_flip(img))
torch.testing.assert_close(actual, F.vertical_flip(img_tensor))
def test_random_subset_apply_random_order(img):
def test_random_subset_apply_random_order(img_tensor_factory):
img_tensor = img_tensor_factory()
flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
# We can't really check whether the transforms are actually applied in random order. However,
# horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
# applies them in random order, we can use a fixed order to compute the expected value.
actual = random_order(img)
expected = v2.Compose(flips)(img)
actual = random_order(img_tensor)
expected = v2.Compose(flips)(img_tensor)
torch.testing.assert_close(actual, expected)
def test_random_subset_apply_valid_transforms(color_jitters, img):
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
img_tensor = img_tensor_factory()
transform = RandomSubsetApply(color_jitters)
output = transform(img)
assert output.shape == img.shape
output = transform(img_tensor)
assert output.shape == img_tensor.shape
def test_random_subset_apply_probability_length_mismatch(color_jitters):
@@ -239,16 +236,18 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
RandomSubsetApply(color_jitters, n_subset=n_subset)
def test_sharpness_jitter_valid_range_tuple(img):
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter((0.1, 2.0))
output = tf(img)
assert output.shape == img.shape
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_valid_range_float(img):
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
img_tensor = img_tensor_factory()
tf = SharpnessJitter(0.5)
output = tf(img)
assert output.shape == img.shape
output = tf(img_tensor)
assert output.shape == img_tensor.shape
def test_sharpness_jitter_invalid_range_min_negative():
@@ -261,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
SharpnessJitter((2.0, 0.1))
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"repo_id, n_examples",
[

tests/test_image_writer.py (new file, 359 lines)

@@ -0,0 +1,359 @@
import queue
import time
from multiprocessing import queues
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from PIL import Image
from lerobot.common.datasets.image_writer import (
AsyncImageWriter,
image_array_to_image,
safe_stop_image_writer,
write_image,
)
DUMMY_IMAGE = "test_image.png"
def test_init_threading():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
try:
assert writer.num_processes == 0
assert writer.num_threads == 2
assert isinstance(writer.queue, queue.Queue)
assert len(writer.threads) == 2
assert len(writer.processes) == 0
assert all(t.is_alive() for t in writer.threads)
finally:
writer.stop()
def test_init_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
assert writer.num_processes == 2
assert writer.num_threads == 2
assert isinstance(writer.queue, queues.JoinableQueue)
assert len(writer.threads) == 0
assert len(writer.processes) == 2
assert all(p.is_alive() for p in writer.processes)
finally:
writer.stop()
def test_zero_threads():
with pytest.raises(ValueError):
AsyncImageWriter(num_processes=0, num_threads=0)
def test_image_array_to_image_rgb(img_array_factory):
img_array = img_array_factory(100, 100)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
def test_image_array_to_image_pytorch_format(img_array_factory):
img_array = img_array_factory(100, 100).transpose(2, 0, 1)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
@pytest.mark.skip("TODO: implement")
def test_image_array_to_image_single_channel(img_array_factory):
img_array = img_array_factory(channels=1)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "L"
def test_image_array_to_image_float_array(img_array_factory):
img_array = img_array_factory(dtype=np.float32)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
def test_image_array_to_image_out_of_bounds_float():
# Float array with values out of [0, 1]
img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
result_image = image_array_to_image(img_array)
assert isinstance(result_image, Image.Image)
assert result_image.size == (100, 100)
assert result_image.mode == "RGB"
assert np.array(result_image).dtype == np.uint8
assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
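The conversion tests above only constrain the observable behavior of image_array_to_image: channel-first input is transposed to channel-last, float input is treated as lying in [0, 1] and scaled to uint8, and out-of-range floats still yield a valid uint8 image. Below is a sketch that satisfies those constraints, under the assumption that clipping before scaling is acceptable; the library's own implementation may differ in details:

```python
# Illustrative conversion meeting the expectations tested above; an assumption, not the library code.
import numpy as np
import PIL.Image


def array_to_pil(img_array: np.ndarray) -> PIL.Image.Image:
    # Accept channel-first (C, H, W) input and move channels last.
    if img_array.ndim == 3 and img_array.shape[0] in (1, 3) and img_array.shape[-1] not in (1, 3):
        img_array = img_array.transpose(1, 2, 0)
    # Float arrays are assumed to live in [0, 1]; clip, then scale to uint8.
    if np.issubdtype(img_array.dtype, np.floating):
        img_array = (np.clip(img_array, 0.0, 1.0) * 255).astype(np.uint8)
    return PIL.Image.fromarray(img_array)
```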
def test_write_image_numpy(tmp_path, img_array_factory):
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
write_image(image_array, fpath)
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
def test_write_image_image(tmp_path, img_factory):
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
write_image(image_pil, fpath)
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
assert np.array_equal(image_pil, saved_image)
def test_write_image_exception(tmp_path):
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
with patch("builtins.print") as mock_print:
write_image(image_array, fpath)
mock_print.assert_called()
assert not fpath.exists()
def test_save_image_numpy(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
finally:
writer.stop()
def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_array, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(image_array, saved_image)
finally:
writer.stop()
def test_save_image_torch(tmp_path, img_tensor_factory):
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_tensor, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_tensor, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
assert np.array_equal(expected_image, saved_image)
finally:
writer.stop()
def test_save_image_pil(tmp_path, img_factory):
writer = AsyncImageWriter()
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_pil, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
finally:
writer.stop()
def test_save_image_pil_multiprocessing(tmp_path, img_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
image_pil = img_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_pil, fpath)
writer.wait_until_done()
assert fpath.exists()
saved_image = Image.open(fpath)
assert list(saved_image.getdata()) == list(image_pil.getdata())
finally:
writer.stop()
def test_save_image_invalid_data(tmp_path):
writer = AsyncImageWriter()
try:
image_array = "invalid data"
fpath = tmp_path / DUMMY_IMAGE
fpath.parent.mkdir(parents=True, exist_ok=True)
with patch("builtins.print") as mock_print:
writer.save_image(image_array, fpath)
writer.wait_until_done()
mock_print.assert_called()
assert not fpath.exists()
finally:
writer.stop()
def test_save_image_after_stop(tmp_path, img_array_factory):
writer = AsyncImageWriter()
writer.stop()
image_array = img_array_factory()
fpath = tmp_path / DUMMY_IMAGE
writer.save_image(image_array, fpath)
time.sleep(1)
assert not fpath.exists()
def test_stop():
writer = AsyncImageWriter(num_processes=0, num_threads=2)
writer.stop()
assert not any(t.is_alive() for t in writer.threads)
def test_stop_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
assert not any(p.is_alive() for p in writer.processes)
def test_multiple_stops():
writer = AsyncImageWriter()
writer.stop()
writer.stop() # Should not raise an exception
assert not any(t.is_alive() for t in writer.threads)
def test_multiple_stops_multiprocessing():
writer = AsyncImageWriter(num_processes=2, num_threads=2)
writer.stop()
writer.stop() # Should not raise an exception
assert not any(p.is_alive() for p in writer.processes)
def test_wait_until_done(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=0, num_threads=4)
try:
num_images = 100
image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
for i, fpath in enumerate(fpaths):
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(saved_image, image_arrays[i])
finally:
writer.stop()
def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
writer = AsyncImageWriter(num_processes=2, num_threads=2)
try:
num_images = 100
image_arrays = [img_array_factory() for _ in range(num_images)]
fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
for image_array, fpath in zip(image_arrays, fpaths, strict=True):
fpath.parent.mkdir(parents=True, exist_ok=True)
writer.save_image(image_array, fpath)
writer.wait_until_done()
for i, fpath in enumerate(fpaths):
assert fpath.exists()
saved_image = np.array(Image.open(fpath))
assert np.array_equal(saved_image, image_arrays[i])
finally:
writer.stop()
def test_exception_handling(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
with (
patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
pytest.raises(queue.Full) as exc_info,
):
writer.save_image(image_array, tmp_path / "test.png")
assert str(exc_info.value) == "Queue is full"
finally:
writer.stop()
def test_with_different_image_formats(tmp_path, img_array_factory):
writer = AsyncImageWriter()
try:
image_array = img_array_factory()
formats = ["png", "jpeg", "bmp"]
for fmt in formats:
fpath = tmp_path / f"test_image.{fmt}"
write_image(image_array, fpath)
assert fpath.exists()
finally:
writer.stop()
def test_safe_stop_image_writer_decorator():
class MockDataset:
def __init__(self):
self.image_writer = MagicMock(spec=AsyncImageWriter)
@safe_stop_image_writer
def function_that_raises_exception(dataset=None):
raise Exception("Test exception")
dataset = MockDataset()
with pytest.raises(Exception) as exc_info:
function_that_raises_exception(dataset=dataset)
assert str(exc_info.value) == "Test exception"
dataset.image_writer.stop.assert_called_once()
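This test only fixes the decorator's contract: the wrapped function's exception keeps propagating, and the dataset's image_writer is stopped exactly once. One way to meet that contract is sketched below, under a hypothetical name so it is not mistaken for the actual safe_stop_image_writer implementation:

```python
# Hypothetical decorator satisfying the contract checked above (re-raise + stop the writer).
import functools


def stop_writer_on_error(func):
    """Made-up name; sketches the behavior the test expects, not the real implementation."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None)
            if image_writer is not None:
                image_writer.stop()
            raise  # the original exception must keep propagating

    return wrapper
```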
def test_main_process_time(tmp_path, img_tensor_factory):
writer = AsyncImageWriter()
try:
image_tensor = img_tensor_factory()
fpath = tmp_path / DUMMY_IMAGE
start_time = time.perf_counter()
writer.save_image(image_tensor, fpath)
end_time = time.perf_counter()
time_spent = end_time - start_time
# Might need to adjust this threshold depending on hardware
assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
writer.wait_until_done()
assert fpath.exists()
finally:
writer.stop()
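The timing assertion above only holds if save_image is a cheap enqueue and the disk work happens in the workers. Below is a simplified sketch of that producer/consumer layout, assuming a threads-only backend and reusing the module-level write_image helper tested earlier; the real AsyncImageWriter may be organized differently:

```python
# Assumed layout: a joinable queue plus worker threads; save_image never touches the disk.
import queue
import threading

from lerobot.common.datasets.image_writer import write_image


class SketchAsyncWriter:
    def __init__(self, num_threads: int = 4):
        self.queue: queue.Queue = queue.Queue()
        self.threads = [threading.Thread(target=self._worker, daemon=True) for _ in range(num_threads)]
        for t in self.threads:
            t.start()

    def _worker(self):
        while True:
            item = self.queue.get()
            try:
                if item is None:  # sentinel pushed by stop()
                    return
                image, fpath = item
                write_image(image, fpath)  # the per-image work unit tested above
            finally:
                self.queue.task_done()

    def save_image(self, image, fpath):
        self.queue.put((image, fpath))  # returns immediately, no disk I/O here

    def wait_until_done(self):
        self.queue.join()  # wait until every queued image has been processed

    def stop(self):
        for _ in self.threads:
            self.queue.put(None)
        for t in self.threads:
            t.join()
```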

View File

@@ -19,11 +19,8 @@ from uuid import uuid4
import numpy as np
import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.utils import hf_transform_to_torch
# Some constants for OnlineBuffer tests.
data_key = "data"
@@ -212,29 +209,17 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [0, 6])
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float
lerobot_dataset_factory,
tmp_path,
offline_dataset_size: int,
online_dataset_size: int,
online_sampling_ratio: float,
):
# Pass/skip the test if both dataset sizes are zero.
if offline_dataset_size + online_dataset_size == 0:
return
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(
hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))})
)
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
if offline_dataset_size == 0:
offline_dataset.episode_data_index = {}
else:
# Set up an episode_data_index with at least two episodes.
offline_dataset.episode_data_index = {
"from": torch.tensor([0, offline_dataset_size // 2]),
"to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]),
}
# Create spoof online dataset.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
online_dataset, _ = make_new_buffer()
if online_dataset_size > 0:
online_dataset.add_data(
@@ -254,16 +239,9 @@ def test_compute_sampler_weights_trivial(
assert torch.allclose(weights, expected_weights)
def test_compute_sampler_weights_nontrivial_ratio():
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {
"from": torch.tensor([0, 2]),
"to": torch.tensor([2, 4]),
}
# Create spoof online dataset.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
online_sampling_ratio = 0.8
@@ -275,16 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio():
)
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
# Create spoof offline dataset.
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {
"from": torch.tensor([0]),
"to": torch.tensor([4]),
}
# Create spoof online dataset.
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
weights = compute_sampler_weights(
@@ -295,18 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
)
def test_compute_sampler_weights_drop_n_last_frames():
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
"""Note: test copied from test_sampler."""
data_dict = {
"timestamp": [0, 0.1],
"index": [0, 1],
"episode_index": [0, 0],
"frame_index": [0, 1],
}
offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict))
offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])}
offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
online_dataset, _ = make_new_buffer()
online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
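These rewritten tests swap the hand-built spoof LeRobotDatasets for the lerobot_dataset_factory fixture but keep exercising compute_sampler_weights. As a reading aid, here is a simplified illustration of the blending they check, with the offline portion weighted by (1 - online_sampling_ratio) and the online portion by online_sampling_ratio; this is assumed semantics, not the function's actual signature, and the drop-last-n-frames option is omitted:

```python
# Simplified illustration of the offline/online weighting exercised by the tests above (assumed).
import torch


def blended_weights(n_offline: int, n_online: int, online_sampling_ratio: float) -> torch.Tensor:
    """Per-sample weights over the concatenated [offline, online] indices."""
    weights = torch.zeros(n_offline + n_online)
    if n_offline > 0:
        weights[:n_offline] = (1 - online_sampling_ratio) / n_offline
    if n_online > 0:
        weights[n_offline:] = online_sampling_ratio / n_online
    return weights


# e.g. 4 offline frames, 8 online frames, online_sampling_ratio=0.8:
# each offline frame gets 0.05, each online frame gets 0.1.
print(blended_weights(4, 8, 0.8))
```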

View File

@@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)
# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name,policy_name,extra_overrides",
[
@@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):
# Check that we can make the policy object.
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
# Check that the policy follows the required protocol.
assert isinstance(
policy, Policy
@@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
env.step(action)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_act_backbone_lr():
"""
Test that the ACT policy can be instantiated with a different learning rate for the backbone.
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
assert cfg.training.lr_backbone == 0.001
dataset = make_dataset(cfg)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
assert len(optimizer.param_groups) == 2
assert optimizer.param_groups[0]["lr"] == cfg.training.lr
@@ -351,6 +352,7 @@ def test_normalize(insert_temporal_dim):
unnormalize(output_batch)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"env_name, policy_name, extra_overrides, file_name_extra",
[
@@ -381,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
include a report on what changed and how that affected the outputs.
2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
add the policies you want to update the test artifacts for.
3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
should be updated.
4. Check that this test now passes.
5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.

View File

@@ -5,7 +5,7 @@ we skip them for now in our CI.
Example to run backward compatibility tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""
@@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
)
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
"required_packages, raw_format, repo_id, make_test_data",
[
@@ -329,7 +330,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
],
)
@pytest.mark.skip(
"Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
"Not compatible with our CI since it downloads raw datasets. Run with `python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
_, dataset_id = repo_id.split("/")

View File

@@ -15,9 +15,9 @@
# limitations under the License.
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
)

View File

@@ -7,10 +7,9 @@ import pytest
import torch
from datasets import Dataset
from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
calculate_episode_data_index,
hf_transform_to_torch,
reset_episode_index,
)
from lerobot.common.utils.utils import (
get_global_random_state,
@@ -73,20 +72,6 @@ def test_calculate_episode_data_index():
assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))
def test_reset_episode_index():
dataset = Dataset.from_dict(
{
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
"index": [0, 1, 2, 3, 4, 5],
"episode_index": [10, 10, 11, 12, 12, 12],
},
)
dataset.set_transform(hf_transform_to_torch)
correct_episode_index = [0, 0, 1, 2, 2, 2]
dataset = reset_episode_index(dataset)
assert dataset["episode_index"] == correct_episode_index
def test_init_hydra_config_empty():
test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
with open(test_file, "w") as f:

View File

@@ -13,25 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset import visualize_dataset
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
@pytest.mark.skip("TODO: add dummy videos")
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
root = tmp_path / "dataset"
output_dir = tmp_path / "outputs"
dataset = lerobot_dataset_factory(root=root)
rrd_path = visualize_dataset(
repo_id,
dataset,
episode_index=0,
batch_size=32,
save=True,
output_dir=tmpdir,
root=root,
output_dir=output_dir,
)
assert rrd_path.exists()

View File

@@ -14,23 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pytest
from lerobot.scripts.visualize_dataset_html import visualize_dataset_html
@pytest.mark.parametrize(
"repo_id",
["lerobot/pusht"],
)
def test_visualize_dataset_html(tmpdir, repo_id):
tmpdir = Path(tmpdir)
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
root = tmp_path / "dataset"
output_dir = tmp_path / "outputs"
dataset = lerobot_dataset_factory(root=root)
visualize_dataset_html(
repo_id,
dataset,
episodes=[0],
output_dir=tmpdir,
output_dir=output_dir,
serve=False,
)
assert (tmpdir / "static" / "episode_0.csv").exists()
assert (output_dir / "static" / "episode_0.csv").exists()