forked from tangger/lerobot

@@ -23,6 +23,13 @@ from lerobot import available_cameras, available_motors, available_robots
from lerobot.common.utils.utils import init_hydra_config
from tests.utils import DEVICE, ROBOT_CONFIG_PATH_TEMPLATE, make_camera, make_motors_bus

# Import fixture modules as plugins
pytest_plugins = [
    "tests.fixtures.dataset_factories",
    "tests.fixtures.files",
    "tests.fixtures.hub",
]


def pytest_collection_finish():
    print(f"\nTesting with {DEVICE=}")
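With these fixture modules registered as plugins, any test module can request the session-scoped factories below without importing them. A minimal usage sketch (illustrative, not part of this diff; the keyword arguments and meta attributes mirror those defined in tests/fixtures/dataset_factories.py):

def test_factory_smoke(lerobot_dataset_factory, tmp_path):
    dataset = lerobot_dataset_factory(root=tmp_path, total_episodes=2, total_frames=20)
    assert dataset.meta.total_episodes == 2
    assert dataset.meta.total_frames == 20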
29 tests/fixtures/constants.py vendored Normal file
@@ -0,0 +1,29 @@
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME

LEROBOT_TEST_DIR = LEROBOT_HOME / "_testing"
DUMMY_REPO_ID = "dummy/repo"
DUMMY_ROBOT_TYPE = "dummy_robot"
DUMMY_MOTOR_FEATURES = {
    "action": {
        "dtype": "float32",
        "shape": (6,),
        "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
    },
    "state": {
        "dtype": "float32",
        "shape": (6,),
        "names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"],
    },
}
DUMMY_CAMERA_FEATURES = {
    "laptop": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
    "phone": {"shape": (480, 640, 3), "names": ["height", "width", "channels"], "info": None},
}
DEFAULT_FPS = 30
DUMMY_VIDEO_INFO = {
    "video.fps": DEFAULT_FPS,
    "video.codec": "av1",
    "video.pix_fmt": "yuv420p",
    "video.is_depth_map": False,
    "has_audio": False,
}
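For orientation (an illustration, not part of the new file): features_factory in tests/fixtures/dataset_factories.py merges these constants, so with use_videos=True a camera key such as "laptop" resolves to roughly:

{
    "dtype": "video",
    "shape": (480, 640, 3),
    "names": ["height", "width", "channels"],
    "info": None,
    "video.fps": 30,
    "video.codec": "av1",
    "video.pix_fmt": "yuv420p",
    "video.is_depth_map": False,
    "has_audio": False,
}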
396 tests/fixtures/dataset_factories.py vendored Normal file
@@ -0,0 +1,396 @@
import random
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.common.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_FEATURES,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.constants import (
|
||||
DEFAULT_FPS,
|
||||
DUMMY_CAMERA_FEATURES,
|
||||
DUMMY_MOTOR_FEATURES,
|
||||
DUMMY_REPO_ID,
|
||||
DUMMY_ROBOT_TYPE,
|
||||
DUMMY_VIDEO_INFO,
|
||||
)
|
||||
|
||||
|
||||
def get_task_index(task_dicts: list[dict], task: str) -> int:
|
||||
tasks = {d["task_index"]: d["task"] for d in task_dicts}
|
||||
task_to_task_index = {task: task_idx for task_idx, task in tasks.items()}
|
||||
return task_to_task_index[task]
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_tensor_factory():
|
||||
def _create_img_tensor(height=100, width=100, channels=3, dtype=torch.float32) -> torch.Tensor:
|
||||
return torch.rand((channels, height, width), dtype=dtype)
|
||||
|
||||
return _create_img_tensor
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||
elif np.issubdtype(dtype, np.floating):
|
||||
# Float array in [0, 1] range
|
||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||
else:
|
||||
raise ValueError(dtype)
|
||||
return img_array
|
||||
|
||||
return _create_img_array
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_factory(img_array_factory):
|
||||
def _create_img(height=100, width=100) -> PIL.Image.Image:
|
||||
img_array = img_array_factory(height=height, width=width)
|
||||
return PIL.Image.fromarray(img_array)
|
||||
|
||||
return _create_img
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
key: {"dtype": "video", **ft, **DUMMY_VIDEO_INFO} for key, ft in camera_features.items()
|
||||
}
|
||||
else:
|
||||
camera_ft = {key: {"dtype": "image", **ft} for key, ft in camera_features.items()}
|
||||
return {
|
||||
**motor_features,
|
||||
**camera_ft,
|
||||
**DEFAULT_FEATURES,
|
||||
}
|
||||
|
||||
return _create_features
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_factory(features_factory):
|
||||
def _create_info(
|
||||
codebase_version: str = CODEBASE_VERSION,
|
||||
fps: int = DEFAULT_FPS,
|
||||
robot_type: str = DUMMY_ROBOT_TYPE,
|
||||
total_episodes: int = 0,
|
||||
total_frames: int = 0,
|
||||
total_tasks: int = 0,
|
||||
total_videos: int = 0,
|
||||
total_chunks: int = 0,
|
||||
chunks_size: int = DEFAULT_CHUNK_SIZE,
|
||||
data_path: str = DEFAULT_PARQUET_PATH,
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
"codebase_version": codebase_version,
|
||||
"robot_type": robot_type,
|
||||
"total_episodes": total_episodes,
|
||||
"total_frames": total_frames,
|
||||
"total_tasks": total_tasks,
|
||||
"total_videos": total_videos,
|
||||
"total_chunks": total_chunks,
|
||||
"chunks_size": chunks_size,
|
||||
"fps": fps,
|
||||
"splits": {},
|
||||
"data_path": data_path,
|
||||
"video_path": video_path if use_videos else None,
|
||||
"features": features,
|
||||
}
|
||||
|
||||
return _create_info
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_factory():
|
||||
def _create_stats(
|
||||
features: dict[str, dict] | None = None,
|
||||
) -> dict:
|
||||
stats = {}
|
||||
for key, ft in features.items():
|
||||
shape = ft["shape"]
|
||||
dtype = ft["dtype"]
|
||||
if dtype in ["image", "video"]:
|
||||
stats[key] = {
|
||||
"max": np.full((3, 1, 1), 1, dtype=np.float32).tolist(),
|
||||
"mean": np.full((3, 1, 1), 0.5, dtype=np.float32).tolist(),
|
||||
"min": np.full((3, 1, 1), 0, dtype=np.float32).tolist(),
|
||||
"std": np.full((3, 1, 1), 0.25, dtype=np.float32).tolist(),
|
||||
}
|
||||
else:
|
||||
stats[key] = {
|
||||
"max": np.full(shape, 1, dtype=dtype).tolist(),
|
||||
"mean": np.full(shape, 0.5, dtype=dtype).tolist(),
|
||||
"min": np.full(shape, 0, dtype=dtype).tolist(),
|
||||
"std": np.full(shape, 0.25, dtype=dtype).tolist(),
|
||||
}
|
||||
return stats
|
||||
|
||||
return _create_stats
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> list[dict]:
|
||||
tasks_list = []
|
||||
for i in range(total_tasks):
|
||||
task_dict = {"task_index": i, "task": f"Perform action {i}."}
|
||||
tasks_list.append(task_dict)
|
||||
return tasks_list
|
||||
|
||||
return _create_tasks
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episodes_factory(tasks_factory):
|
||||
def _create_episodes(
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 400,
|
||||
tasks: list[dict] | None = None,
|
||||
multi_task: bool = False,
|
||||
):
|
||||
if total_episodes <= 0 or total_frames <= 0:
|
||||
raise ValueError("num_episodes and total_length must be positive integers.")
|
||||
if total_frames < total_episodes:
|
||||
raise ValueError("total_length must be greater than or equal to num_episodes.")
|
||||
|
||||
if not tasks:
|
||||
min_tasks = 2 if multi_task else 1
|
||||
total_tasks = random.randint(min_tasks, total_episodes)
|
||||
tasks = tasks_factory(total_tasks)
|
||||
|
||||
if total_episodes < len(tasks) and not multi_task:
|
||||
raise ValueError("The number of tasks should be less than the number of episodes.")
|
||||
|
||||
# Generate random episode lengths that sum up to total_frames
|
||||
lengths = np.random.multinomial(total_frames, [1 / total_episodes] * total_episodes).tolist()
|
||||
|
||||
tasks_list = [task_dict["task"] for task_dict in tasks]
|
||||
num_tasks_available = len(tasks_list)
|
||||
|
||||
episodes_list = []
|
||||
remaining_tasks = tasks_list.copy()
|
||||
for ep_idx in range(total_episodes):
|
||||
num_tasks_in_episode = random.randint(1, min(3, num_tasks_available)) if multi_task else 1
|
||||
tasks_to_sample = remaining_tasks if remaining_tasks else tasks_list
|
||||
episode_tasks = random.sample(tasks_to_sample, min(num_tasks_in_episode, len(tasks_to_sample)))
|
||||
if remaining_tasks:
|
||||
for task in episode_tasks:
|
||||
remaining_tasks.remove(task)
|
||||
|
||||
episodes_list.append(
|
||||
{
|
||||
"episode_index": ep_idx,
|
||||
"tasks": episode_tasks,
|
||||
"length": lengths[ep_idx],
|
||||
}
|
||||
)
|
||||
|
||||
return episodes_list
|
||||
|
||||
return _create_episodes
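A sketch of the returned structure (illustrative; the actual lengths are drawn from a multinomial and the task assignment is random): episodes_factory(total_episodes=3, total_frames=30) could return:

[
    {"episode_index": 0, "tasks": ["Perform action 0."], "length": 11},
    {"episode_index": 1, "tasks": ["Perform action 1."], "length": 9},
    {"episode_index": 2, "tasks": ["Perform action 0."], "length": 10},
]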
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_array_factory):
|
||||
def _create_hf_dataset(
|
||||
features: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
fps: int = DEFAULT_FPS,
|
||||
) -> datasets.Dataset:
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
if not features:
|
||||
features = features_factory()
|
||||
|
||||
timestamp_col = np.array([], dtype=np.float32)
|
||||
frame_index_col = np.array([], dtype=np.int64)
|
||||
episode_index_col = np.array([], dtype=np.int64)
|
||||
task_index = np.array([], dtype=np.int64)
|
||||
for ep_dict in episodes:
|
||||
timestamp_col = np.concatenate((timestamp_col, np.arange(ep_dict["length"]) / fps))
|
||||
frame_index_col = np.concatenate((frame_index_col, np.arange(ep_dict["length"], dtype=int)))
|
||||
episode_index_col = np.concatenate(
|
||||
(episode_index_col, np.full(ep_dict["length"], ep_dict["episode_index"], dtype=int))
|
||||
)
|
||||
ep_task_index = get_task_index(tasks, ep_dict["tasks"][0])
|
||||
task_index = np.concatenate((task_index, np.full(ep_dict["length"], ep_task_index, dtype=int)))
|
||||
|
||||
index_col = np.arange(len(episode_index_col))
|
||||
|
||||
robot_cols = {}
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(height=ft["shapes"][1], width=ft["shapes"][0])
|
||||
for _ in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||
|
||||
hf_features = get_hf_features_from_features(features)
|
||||
dataset = datasets.Dataset.from_dict(
|
||||
{
|
||||
**robot_cols,
|
||||
"timestamp": timestamp_col,
|
||||
"frame_index": frame_index_col,
|
||||
"episode_index": episode_index_col,
|
||||
"index": index_col,
|
||||
"task_index": task_index,
|
||||
},
|
||||
features=hf_features,
|
||||
)
|
||||
dataset.set_transform(hf_transform_to_torch)
|
||||
return dataset
|
||||
|
||||
return _create_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_metadata_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
mock_snapshot_download_factory,
|
||||
):
|
||||
def _create_lerobot_dataset_metadata(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
local_files_only: bool = False,
|
||||
) -> LeRobotDatasetMetadata:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.get_hub_safe_version"
|
||||
) as mock_get_hub_safe_version_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_hub_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDatasetMetadata(repo_id=repo_id, root=root, local_files_only=local_files_only)
|
||||
|
||||
return _create_lerobot_dataset_metadata
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def lerobot_dataset_factory(
|
||||
info_factory,
|
||||
stats_factory,
|
||||
tasks_factory,
|
||||
episodes_factory,
|
||||
hf_dataset_factory,
|
||||
mock_snapshot_download_factory,
|
||||
lerobot_dataset_metadata_factory,
|
||||
):
|
||||
def _create_lerobot_dataset(
|
||||
root: Path,
|
||||
repo_id: str = DUMMY_REPO_ID,
|
||||
total_episodes: int = 3,
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episode_dicts: list[dict] | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
**kwargs,
|
||||
) -> LeRobotDataset:
|
||||
if not info:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
)
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episode_dicts:
|
||||
episode_dicts = episodes_factory(
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episode_dicts, fps=info["fps"])
|
||||
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
hf_dataset=hf_dataset,
|
||||
)
|
||||
mock_metadata = lerobot_dataset_metadata_factory(
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
info=info,
|
||||
stats=stats,
|
||||
tasks=tasks,
|
||||
episodes=episode_dicts,
|
||||
local_files_only=kwargs.get("local_files_only", False),
|
||||
)
|
||||
with (
|
||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||
patch(
|
||||
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||
) as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_metadata_patch.return_value = mock_metadata
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
return LeRobotDataset(repo_id=repo_id, root=root, **kwargs)
|
||||
|
||||
return _create_lerobot_dataset
|
||||
114 tests/fixtures/files.py vendored Normal file
@@ -0,0 +1,114 @@
import json
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import jsonlines
|
||||
import pyarrow.compute as pc
|
||||
import pyarrow.parquet as pq
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def info_path(info_factory):
|
||||
def _create_info_json_file(dir: Path, info: dict | None = None) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
fpath = dir / INFO_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(info, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
return _create_info_json_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def stats_path(stats_factory):
|
||||
def _create_stats_json_file(dir: Path, stats: dict | None = None) -> Path:
|
||||
if not stats:
|
||||
stats = stats_factory()
|
||||
fpath = dir / STATS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(stats, f, indent=4, ensure_ascii=False)
|
||||
return fpath
|
||||
|
||||
return _create_stats_json_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tasks_path(tasks_factory):
|
||||
def _create_tasks_jsonl_file(dir: Path, tasks: list | None = None) -> Path:
|
||||
if not tasks:
|
||||
tasks = tasks_factory()
|
||||
fpath = dir / TASKS_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(tasks)
|
||||
return fpath
|
||||
|
||||
return _create_tasks_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def episode_path(episodes_factory):
|
||||
def _create_episodes_jsonl_file(dir: Path, episodes: list | None = None) -> Path:
|
||||
if not episodes:
|
||||
episodes = episodes_factory()
|
||||
fpath = dir / EPISODES_PATH
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
with jsonlines.open(fpath, "w") as writer:
|
||||
writer.write_all(episodes)
|
||||
return fpath
|
||||
|
||||
return _create_episodes_jsonl_file
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def single_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_single_episode_parquet(
|
||||
dir: Path, ep_idx: int = 0, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
data_path = info["data_path"]
|
||||
chunks_size = info["chunks_size"]
|
||||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
pq.write_table(ep_table, fpath)
|
||||
return fpath
|
||||
|
||||
return _create_single_episode_parquet
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def multi_episode_parquet_path(hf_dataset_factory, info_factory):
|
||||
def _create_multi_episode_parquet(
|
||||
dir: Path, hf_dataset: datasets.Dataset | None = None, info: dict | None = None
|
||||
) -> Path:
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory()
|
||||
|
||||
data_path = info["data_path"]
|
||||
chunks_size = info["chunks_size"]
|
||||
total_episodes = info["total_episodes"]
|
||||
for ep_idx in range(total_episodes):
|
||||
ep_chunk = ep_idx // chunks_size
|
||||
fpath = dir / data_path.format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
fpath.parent.mkdir(parents=True, exist_ok=True)
|
||||
table = hf_dataset.data.table
|
||||
ep_table = table.filter(pc.equal(table["episode_index"], ep_idx))
|
||||
pq.write_table(ep_table, fpath)
|
||||
return dir / "data"
|
||||
|
||||
return _create_multi_episode_parquet
|
||||
105 tests/fixtures/hub.py vendored Normal file
@@ -0,0 +1,105 @@
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import pytest
|
||||
from huggingface_hub.utils import filter_repo_objects
|
||||
|
||||
from lerobot.common.datasets.utils import EPISODES_PATH, INFO_PATH, STATS_PATH, TASKS_PATH
|
||||
from tests.fixtures.constants import LEROBOT_TEST_DIR
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def mock_snapshot_download_factory(
|
||||
info_factory,
|
||||
info_path,
|
||||
stats_factory,
|
||||
stats_path,
|
||||
tasks_factory,
|
||||
tasks_path,
|
||||
episodes_factory,
|
||||
episode_path,
|
||||
single_episode_parquet_path,
|
||||
hf_dataset_factory,
|
||||
):
|
||||
"""
|
||||
This factory patches snapshot_download so that, when called, it creates the expected files locally
instead of making calls to the Hub API. It lets you explicitly pass the files you want to be created.
|
||||
"""
|
||||
|
||||
def _mock_snapshot_download_func(
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: list[dict] | None = None,
|
||||
episodes: list[dict] | None = None,
|
||||
hf_dataset: datasets.Dataset | None = None,
|
||||
):
|
||||
if not info:
|
||||
info = info_factory()
|
||||
if not stats:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if not tasks:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if not episodes:
|
||||
episodes = episodes_factory(
|
||||
total_episodes=info["total_episodes"], total_frames=info["total_frames"], tasks=tasks
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes, fps=info["fps"])
|
||||
|
||||
def _extract_episode_index_from_path(fpath: str) -> int:
|
||||
path = Path(fpath)
|
||||
if path.suffix == ".parquet" and path.stem.startswith("episode_"):
|
||||
episode_index = int(path.stem[len("episode_") :]) # 'episode_000000' -> 0
|
||||
return episode_index
|
||||
else:
|
||||
return None
|
||||
|
||||
def _mock_snapshot_download(
|
||||
repo_id: str,
|
||||
local_dir: str | Path | None = None,
|
||||
allow_patterns: str | list[str] | None = None,
|
||||
ignore_patterns: str | list[str] | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not local_dir:
|
||||
local_dir = LEROBOT_TEST_DIR
|
||||
|
||||
# List all possible files
|
||||
all_files = []
|
||||
meta_files = [INFO_PATH, STATS_PATH, TASKS_PATH, EPISODES_PATH]
|
||||
all_files.extend(meta_files)
|
||||
|
||||
data_files = []
|
||||
for episode_dict in episodes:
|
||||
ep_idx = episode_dict["episode_index"]
|
||||
ep_chunk = ep_idx // info["chunks_size"]
|
||||
data_path = info["data_path"].format(episode_chunk=ep_chunk, episode_index=ep_idx)
|
||||
data_files.append(data_path)
|
||||
all_files.extend(data_files)
|
||||
|
||||
allowed_files = filter_repo_objects(
|
||||
all_files, allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||
)
|
||||
|
||||
# Create allowed files
|
||||
for rel_path in allowed_files:
|
||||
if rel_path.startswith("data/"):
|
||||
episode_index = _extract_episode_index_from_path(rel_path)
|
||||
if episode_index is not None:
|
||||
_ = single_episode_parquet_path(local_dir, episode_index, hf_dataset, info)
|
||||
if rel_path == INFO_PATH:
|
||||
_ = info_path(local_dir, info)
|
||||
elif rel_path == STATS_PATH:
|
||||
_ = stats_path(local_dir, stats)
|
||||
elif rel_path == TASKS_PATH:
|
||||
_ = tasks_path(local_dir, tasks)
|
||||
elif rel_path == EPISODES_PATH:
|
||||
_ = episode_path(local_dir, episodes)
|
||||
else:
|
||||
pass
|
||||
return str(local_dir)
|
||||
|
||||
return _mock_snapshot_download
|
||||
|
||||
return _mock_snapshot_download_func
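A usage sketch (mirroring how the dataset factories earlier in this diff consume this fixture; the patch target is the one used there):

mock_snapshot_download = mock_snapshot_download_factory(info=info, stats=stats, tasks=tasks, episodes=episodes)
with patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch:
    mock_snapshot_download_patch.side_effect = mock_snapshot_download
    # Code under test that calls snapshot_download now writes the meta files and per-episode
    # parquet files into local_dir (or LEROBOT_TEST_DIR) instead of contacting the Hub.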
|
||||
@@ -76,7 +76,7 @@ def main():
|
||||
dataset = LeRobotDataset(DATASET_REPO_ID, image_transforms=None)
|
||||
output_dir = Path(ARTIFACT_DIR)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
original_frame = dataset[0][dataset.camera_keys[0]]
|
||||
original_frame = dataset[0][dataset.meta.camera_keys[0]]
|
||||
|
||||
save_single_transforms(original_frame, output_dir)
|
||||
save_default_config_transform(original_frame, output_dir)
|
||||
|
||||
@@ -38,7 +38,7 @@ def get_policy_stats(env_name, policy_name, extra_overrides):
|
||||
)
|
||||
set_global_seed(1337)
|
||||
dataset = make_dataset(cfg)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(cfg, dataset_stats=dataset.meta.stats)
|
||||
policy.train()
|
||||
optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
|
||||
@@ -29,7 +29,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.datasets.populate_dataset import add_frame, init_dataset
|
||||
from lerobot.common.logger import Logger
|
||||
from lerobot.common.policies.factory import make_policy
|
||||
from lerobot.common.utils.utils import init_hydra_config
|
||||
@@ -93,8 +92,9 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
||||
mock_calibration_dir(calibration_dir)
|
||||
overrides.append(f"calibration_dir={calibration_dir}")
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
record(
|
||||
@@ -102,6 +102,7 @@ def test_record_without_cameras(tmpdir, request, robot_type, mock):
|
||||
fps=30,
|
||||
root=root,
|
||||
repo_id=repo_id,
|
||||
single_task=single_task,
|
||||
warmup_time_s=1,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
@@ -132,17 +133,18 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
env_name = "koch_real"
|
||||
policy_name = "act_koch_real"
|
||||
|
||||
root = tmpdir / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
eval_repo_id = "lerobot/eval_debug"
|
||||
root = tmpdir / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=1,
|
||||
warmup_time_s=1,
|
||||
warmup_time_s=0.5,
|
||||
episode_time_s=1,
|
||||
reset_time_s=1,
|
||||
num_episodes=2,
|
||||
@@ -153,7 +155,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
)
|
||||
assert dataset.num_episodes == 2
|
||||
assert dataset.meta.total_episodes == 2
|
||||
assert len(dataset) == 2
|
||||
|
||||
replay(robot, episode=0, fps=1, root=root, repo_id=repo_id, play_sounds=False)
|
||||
@@ -191,7 +193,7 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
|
||||
policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
out_dir = tmpdir / "logger"
|
||||
logger = Logger(cfg, out_dir, wandb_job_name="debug")
|
||||
@@ -225,10 +227,14 @@ def test_record_and_replay_and_policy(tmpdir, request, robot_type, mock):
|
||||
else:
|
||||
num_image_writer_processes = 0
|
||||
|
||||
record(
|
||||
eval_repo_id = "lerobot/eval_debug"
|
||||
eval_root = tmpdir / "data" / eval_repo_id
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
eval_root,
|
||||
eval_repo_id,
|
||||
single_task,
|
||||
pretrained_policy_name_or_path,
|
||||
warmup_time_s=1,
|
||||
episode_time_s=1,
|
||||
@@ -265,51 +271,36 @@ def test_resume_record(tmpdir, request, robot_type, mock):
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=1,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
record_kwargs = {
|
||||
"robot": robot,
|
||||
"root": root,
|
||||
"repo_id": repo_id,
|
||||
"single_task": single_task,
|
||||
"fps": 1,
|
||||
"warmup_time_s": 0,
|
||||
"episode_time_s": 1,
|
||||
"push_to_hub": False,
|
||||
"video": False,
|
||||
"display_cameras": False,
|
||||
"play_sounds": False,
|
||||
"run_compute_stats": False,
|
||||
"local_files_only": True,
|
||||
"num_episodes": 1,
|
||||
}
|
||||
|
||||
init_dataset_return_value = {}
|
||||
dataset = record(**record_kwargs)
|
||||
assert len(dataset) == 1, f"`dataset` should contain 1 frame, not {len(dataset)}"
|
||||
|
||||
def wrapped_init_dataset(*args, **kwargs):
|
||||
nonlocal init_dataset_return_value
|
||||
init_dataset_return_value = init_dataset(*args, **kwargs)
|
||||
return init_dataset_return_value
|
||||
with pytest.raises(FileExistsError):
|
||||
# Dataset already exists, but resume=False by default
|
||||
record(**record_kwargs)
|
||||
|
||||
with patch("lerobot.scripts.control_robot.init_dataset", wraps=wrapped_init_dataset):
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
num_episodes=2,
|
||||
push_to_hub=False,
|
||||
video=False,
|
||||
display_cameras=False,
|
||||
play_sounds=False,
|
||||
run_compute_stats=False,
|
||||
)
|
||||
assert len(dataset) == 2, "`dataset` should contain only 1 frame"
|
||||
assert (
|
||||
init_dataset_return_value["num_episodes"] == 2
|
||||
), "`init_dataset` should load the previous episode"
|
||||
dataset = record(**record_kwargs, resume=True)
|
||||
assert len(dataset) == 2, f"`dataset` should contain 2 frames, not {len(dataset)}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("robot_type, mock", [("koch", True)])
|
||||
@@ -328,23 +319,22 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = True
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -358,7 +348,6 @@ def test_record_with_event_rerecord_episode(tmpdir, request, robot_type, mock):
|
||||
|
||||
assert not mock_events["rerecord_episode"], "`rerecord_episode` wasn't properly reset to False"
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 2, "`add_frame` should have been called 2 times"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@@ -378,23 +367,22 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = False
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
fps=2,
|
||||
root=root,
|
||||
single_task=single_task,
|
||||
repo_id=repo_id,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -407,7 +395,6 @@ def test_record_with_event_exit_early(tmpdir, request, robot_type, mock):
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
|
||||
@@ -429,23 +416,22 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
overrides = []
|
||||
|
||||
robot = make_robot(robot_type, overrides=overrides, mock=mock)
|
||||
with (
|
||||
patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener,
|
||||
patch("lerobot.common.robot_devices.control_utils.add_frame", wraps=add_frame) as mock_add_frame,
|
||||
):
|
||||
with patch("lerobot.scripts.control_robot.init_keyboard_listener") as mock_listener:
|
||||
mock_events = {}
|
||||
mock_events["exit_early"] = True
|
||||
mock_events["rerecord_episode"] = False
|
||||
mock_events["stop_recording"] = True
|
||||
mock_listener.return_value = (None, mock_events)
|
||||
|
||||
root = Path(tmpdir) / "data"
|
||||
repo_id = "lerobot/debug"
|
||||
root = Path(tmpdir) / "data" / repo_id
|
||||
single_task = "Do something."
|
||||
|
||||
dataset = record(
|
||||
robot,
|
||||
root,
|
||||
repo_id,
|
||||
single_task=single_task,
|
||||
fps=1,
|
||||
warmup_time_s=0,
|
||||
episode_time_s=1,
|
||||
@@ -459,5 +445,4 @@ def test_record_with_event_stop_recording(tmpdir, request, robot_type, mock, num
|
||||
)
|
||||
|
||||
assert not mock_events["exit_early"], "`exit_early` wasn't properly reset to False"
|
||||
assert mock_add_frame.call_count == 1, "`add_frame` should have been called 1 time"
|
||||
assert len(dataset) == 1, "`dataset` should contain only 1 frame"
|
||||
|
||||
@@ -33,18 +33,72 @@ from lerobot.common.datasets.compute_stats import (
|
||||
get_stats_einops_patterns,
|
||||
)
|
||||
from lerobot.common.datasets.factory import make_dataset
|
||||
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, MultiLeRobotDataset
|
||||
from lerobot.common.datasets.lerobot_dataset import (
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
)
|
||||
from lerobot.common.datasets.utils import (
|
||||
create_branch,
|
||||
flatten_dict,
|
||||
hf_transform_to_torch,
|
||||
load_previous_and_future_frames,
|
||||
unflatten_dict,
|
||||
)
|
||||
from lerobot.common.utils.utils import init_hydra_config, seeded_context
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.utils import DEFAULT_CONFIG_PATH, DEVICE, make_robot
|
||||
|
||||
|
||||
def test_same_attributes_defined(lerobot_dataset_factory, tmp_path):
|
||||
"""
|
||||
Instantiate a LeRobotDataset both ways with '__init__()' and 'create()' and verify that instantiated
|
||||
objects have the same sets of attributes defined.
|
||||
"""
|
||||
# Instantiate both ways
|
||||
robot = make_robot("koch", mock=True)
|
||||
root_create = tmp_path / "create"
|
||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, robot=robot, root=root_create)
|
||||
|
||||
root_init = tmp_path / "init"
|
||||
dataset_init = lerobot_dataset_factory(root=root_init)
|
||||
|
||||
# Access the '_hub_version' cached_property in both instances to force its creation
|
||||
_ = dataset_init.meta._hub_version
|
||||
_ = dataset_create.meta._hub_version
|
||||
|
||||
init_attr = set(vars(dataset_init).keys())
|
||||
create_attr = set(vars(dataset_create).keys())
|
||||
|
||||
assert init_attr == create_attr
|
||||
|
||||
|
||||
def test_dataset_initialization(lerobot_dataset_factory, tmp_path):
|
||||
kwargs = {
|
||||
"repo_id": DUMMY_REPO_ID,
|
||||
"total_episodes": 10,
|
||||
"total_frames": 400,
|
||||
"episodes": [2, 5, 6],
|
||||
}
|
||||
dataset = lerobot_dataset_factory(root=tmp_path, **kwargs)
|
||||
|
||||
assert dataset.repo_id == kwargs["repo_id"]
|
||||
assert dataset.meta.total_episodes == kwargs["total_episodes"]
|
||||
assert dataset.meta.total_frames == kwargs["total_frames"]
|
||||
assert dataset.episodes == kwargs["episodes"]
|
||||
assert dataset.num_episodes == len(kwargs["episodes"])
|
||||
assert dataset.num_frames == len(dataset)
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
# - [ ] test add_frame
|
||||
# - [ ] test add_episode
|
||||
# - [ ] test consolidate
|
||||
# - [ ] test push_to_hub
|
||||
# - [ ] test smaller methods
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"env_name, repo_id, policy_name",
|
||||
lerobot.env_dataset_policy_triplets
|
||||
@@ -67,7 +121,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
)
|
||||
dataset = make_dataset(cfg)
|
||||
delta_timestamps = dataset.delta_timestamps
|
||||
camera_keys = dataset.camera_keys
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
item = dataset[0]
|
||||
|
||||
@@ -117,6 +171,7 @@ def test_factory(env_name, repo_id, policy_name):
|
||||
|
||||
|
||||
# TODO(alexander-soare): If you're hunting for savings on testing time, this takes about 5 seconds.
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_multilerobotdataset_frames():
|
||||
"""Check that all dataset frames are incorporated."""
|
||||
# Note: use the image variants of the dataset to make the test approx 3x faster.
|
||||
@@ -130,7 +185,7 @@ def test_multilerobotdataset_frames():
|
||||
sub_datasets = [LeRobotDataset(repo_id) for repo_id in repo_ids]
|
||||
dataset = MultiLeRobotDataset(repo_ids)
|
||||
assert len(dataset) == sum(len(d) for d in sub_datasets)
|
||||
assert dataset.num_samples == sum(d.num_samples for d in sub_datasets)
|
||||
assert dataset.num_frames == sum(d.num_frames for d in sub_datasets)
|
||||
assert dataset.num_episodes == sum(d.num_episodes for d in sub_datasets)
|
||||
|
||||
# Run through all items of the LeRobotDatasets in parallel with the items of the MultiLerobotDataset and
|
||||
@@ -149,6 +204,8 @@ def test_multilerobotdataset_frames():
|
||||
assert torch.equal(sub_dataset_item[k], dataset_item[k])
|
||||
|
||||
|
||||
# TODO(aliberts, rcadene): Refactor and move this to a tests/test_compute_stats.py
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_compute_stats_on_xarm():
|
||||
"""Check that the statistics are computed correctly according to the stats_patterns property.
|
||||
|
||||
@@ -197,7 +254,7 @@ def test_compute_stats_on_xarm():
|
||||
assert torch.allclose(computed_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
# load stats used during training which are expected to match the ones returned by computed_stats
|
||||
loaded_stats = dataset.stats # noqa: F841
|
||||
loaded_stats = dataset.meta.stats # noqa: F841
|
||||
|
||||
# TODO(rcadene): we can't test this because expected_stats is computed on a subset
|
||||
# # test loaded stats match expected stats
|
||||
@@ -208,72 +265,7 @@ def test_compute_stats_on_xarm():
|
||||
# assert torch.allclose(loaded_stats[k]["max"], expected_stats[k]["max"])
|
||||
|
||||
|
||||
def test_load_previous_and_future_frames_within_tolerance():
|
||||
hf_dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.139]}
|
||||
tol = 0.04
|
||||
item = hf_dataset[2]
|
||||
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 2, 3])), "Data does not match expected values"
|
||||
assert not is_pad.any(), "Unexpected padding detected"
|
||||
|
||||
|
||||
def test_load_previous_and_future_frames_outside_tolerance_inside_episode_range():
|
||||
hf_dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.2, 0, 0.141]}
|
||||
tol = 0.04
|
||||
item = hf_dataset[2]
|
||||
with pytest.raises(AssertionError):
|
||||
load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
|
||||
|
||||
def test_load_previous_and_future_frames_outside_tolerance_outside_episode_range():
|
||||
hf_dataset = Dataset.from_dict(
|
||||
{
|
||||
"timestamp": [0.1, 0.2, 0.3, 0.4, 0.5],
|
||||
"index": [0, 1, 2, 3, 4],
|
||||
"episode_index": [0, 0, 0, 0, 0],
|
||||
}
|
||||
)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"from": torch.tensor([0]),
|
||||
"to": torch.tensor([5]),
|
||||
}
|
||||
delta_timestamps = {"index": [-0.3, -0.24, 0, 0.26, 0.3]}
|
||||
tol = 0.04
|
||||
item = hf_dataset[2]
|
||||
item = load_previous_and_future_frames(item, hf_dataset, episode_data_index, delta_timestamps, tol)
|
||||
data, is_pad = item["index"], item["index_is_pad"]
|
||||
assert torch.equal(data, torch.tensor([0, 0, 2, 4, 4])), "Data does not match expected values"
|
||||
assert torch.equal(
|
||||
is_pad, torch.tensor([True, False, False, True, True])
|
||||
), "Padding does not match expected values"
|
||||
|
||||
|
||||
# TODO(aliberts): Move to more appropriate location
|
||||
def test_flatten_unflatten_dict():
|
||||
d = {
|
||||
"obs": {
|
||||
@@ -297,6 +289,7 @@ def test_flatten_unflatten_dict():
|
||||
assert json.dumps(original_d, sort_keys=True) == json.dumps(d, sort_keys=True), f"{original_d} != {d}"
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
@pytest.mark.parametrize(
|
||||
"repo_id",
|
||||
[
|
||||
@@ -368,6 +361,7 @@ def test_backward_compatibility(repo_id):
|
||||
# load_and_compare(i - 1)
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO after v2 migration / removing hydra")
|
||||
def test_aggregate_stats():
|
||||
"""Makes 3 basic datasets and checks that aggregate stats are computed correctly."""
|
||||
with seeded_context(0):
|
||||
|
||||
256 tests/test_delta_timestamps.py Normal file
@@ -0,0 +1,256 @@
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.common.datasets.utils import (
|
||||
calculate_episode_data_index,
|
||||
check_delta_timestamps,
|
||||
check_timestamps_sync,
|
||||
get_delta_indices,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from tests.fixtures.constants import DUMMY_MOTOR_FEATURES
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def synced_hf_dataset_factory(hf_dataset_factory):
|
||||
def _create_synced_hf_dataset(fps: int = 30) -> Dataset:
|
||||
return hf_dataset_factory(fps=fps)
|
||||
|
||||
return _create_synced_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def unsynced_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_unsynced_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just outside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 1.1))
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
|
||||
return _create_unsynced_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_hf_dataset_factory(synced_hf_dataset_factory):
|
||||
def _create_slightly_off_hf_dataset(fps: int = 30, tolerance_s: float = 1e-4) -> Dataset:
|
||||
hf_dataset = synced_hf_dataset_factory(fps=fps)
|
||||
features = hf_dataset.features
|
||||
df = hf_dataset.to_pandas()
|
||||
dtype = df["timestamp"].dtype # This is to avoid pandas type warning
|
||||
# Modify a single timestamp just inside tolerance
|
||||
df.at[30, "timestamp"] = dtype.type(df.at[30, "timestamp"] + (tolerance_s * 0.9))
|
||||
unsynced_hf_dataset = Dataset.from_pandas(df, features=features)
|
||||
unsynced_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return unsynced_hf_dataset
|
||||
|
||||
return _create_slightly_off_hf_dataset
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_delta_timestamps_factory():
|
||||
def _create_valid_delta_timestamps(fps: int = 30, keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
delta_timestamps = {key: [i * (1 / fps) for i in range(-10, 10)] for key in keys}
|
||||
return delta_timestamps
|
||||
|
||||
return _create_valid_delta_timestamps
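For intuition (illustrative numbers, not taken from the test file): the valid deltas are exact multiples of 1/fps, check_delta_timestamps accepts any delta within tolerance_s of such a multiple, and get_delta_indices converts deltas to frame offsets. For example, with fps=10:

delta_timestamps = {"action": [-0.2, 0.0, 0.1]}
get_delta_indices(delta_timestamps, fps=10)  # expected: {"action": [-2, 0, 1]}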
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def invalid_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_invalid_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just outside tolerance
|
||||
for key in keys:
|
||||
delta_timestamps[key][3] += tolerance_s * 1.1
|
||||
return delta_timestamps
|
||||
|
||||
return _create_invalid_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def slightly_off_delta_timestamps_factory(valid_delta_timestamps_factory):
|
||||
def _create_slightly_off_delta_timestamps(
|
||||
fps: int = 30, tolerance_s: float = 1e-4, keys: list = DUMMY_MOTOR_FEATURES
|
||||
) -> dict:
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps, keys)
|
||||
# Modify a single timestamp just inside tolerance
|
||||
for key in delta_timestamps:
|
||||
delta_timestamps[key][3] += tolerance_s * 0.9
|
||||
delta_timestamps[key][-3] += tolerance_s * 0.9
|
||||
return delta_timestamps
|
||||
|
||||
return _create_slightly_off_delta_timestamps
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def delta_indices(keys: list = DUMMY_MOTOR_FEATURES) -> dict:
|
||||
return {key: list(range(-10, 10)) for key in keys}
|
||||
|
||||
|
||||
def test_check_timestamps_sync_synced(synced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
synced_hf_dataset = synced_hf_dataset_factory(fps)
|
||||
episode_data_index = calculate_episode_data_index(synced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=synced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
with pytest.raises(ValueError):
|
||||
check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
|
||||
|
||||
def test_check_timestamps_sync_unsynced_no_exception(unsynced_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
unsynced_hf_dataset = unsynced_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(unsynced_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=unsynced_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
raise_value_error=False,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_check_timestamps_sync_slightly_off(slightly_off_hf_dataset_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_hf_dataset = slightly_off_hf_dataset_factory(fps, tolerance_s)
|
||||
episode_data_index = calculate_episode_data_index(slightly_off_hf_dataset)
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=slightly_off_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_timestamps_sync_single_timestamp():
|
||||
single_timestamp_hf_dataset = Dataset.from_dict({"timestamp": [0.0], "episode_index": [0]})
|
||||
single_timestamp_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {"to": torch.tensor([1]), "from": torch.tensor([0])}
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=single_timestamp_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
# TODO(aliberts): Change behavior of hf_transform_to_torch so that it can work with empty dataset
|
||||
@pytest.mark.skip("TODO: fix")
|
||||
def test_check_timestamps_sync_empty_dataset():
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
empty_hf_dataset = Dataset.from_dict({"timestamp": [], "episode_index": []})
|
||||
empty_hf_dataset.set_transform(hf_transform_to_torch)
|
||||
episode_data_index = {
|
||||
"to": torch.tensor([], dtype=torch.int64),
|
||||
"from": torch.tensor([], dtype=torch.int64),
|
||||
}
|
||||
result = check_timestamps_sync(
|
||||
hf_dataset=empty_hf_dataset,
|
||||
episode_data_index=episode_data_index,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_valid(valid_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
valid_delta_timestamps = valid_delta_timestamps_factory(fps)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=valid_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_slightly_off(slightly_off_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
slightly_off_delta_timestamps = slightly_off_delta_timestamps_factory(fps, tolerance_s)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=slightly_off_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_check_delta_timestamps_invalid(invalid_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
|
||||
with pytest.raises(ValueError):
|
||||
check_delta_timestamps(
|
||||
delta_timestamps=invalid_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
|
||||
|
||||
def test_check_delta_timestamps_invalid_no_exception(invalid_delta_timestamps_factory):
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
invalid_delta_timestamps = invalid_delta_timestamps_factory(fps, tolerance_s)
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=invalid_delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
raise_value_error=False,
|
||||
)
|
||||
assert result is False
|
||||
|
||||
|
||||
def test_check_delta_timestamps_empty():
|
||||
delta_timestamps = {}
|
||||
fps = 30
|
||||
tolerance_s = 1e-4
|
||||
result = check_delta_timestamps(
|
||||
delta_timestamps=delta_timestamps,
|
||||
fps=fps,
|
||||
tolerance_s=tolerance_s,
|
||||
)
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_delta_indices(valid_delta_timestamps_factory, delta_indices):
|
||||
fps = 30
|
||||
delta_timestamps = valid_delta_timestamps_factory(fps)
|
||||
expected_delta_indices = delta_indices
|
||||
actual_delta_indices = get_delta_indices(delta_timestamps, fps)
|
||||
assert expected_delta_indices == actual_delta_indices
|
||||
@@ -13,12 +13,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# TODO(aliberts): Mute logging for these tests
|
||||
|
||||
import io
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
from tests.utils import require_package
|
||||
|
||||
|
||||
@@ -29,6 +32,7 @@ def _find_and_replace(text: str, finds_and_replaces: list[tuple[str, str]]) -> s
|
||||
return text
|
||||
|
||||
|
||||
# TODO(aliberts): Remove usage of subprocess calls and patch code with fixtures
|
||||
def _run_script(path):
|
||||
subprocess.run([sys.executable, path], check=True)
|
||||
|
||||
@@ -38,12 +42,26 @@ def _read_file(path):
|
||||
return file.read()
|
||||
|
||||
|
||||
def test_example_1():
|
||||
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
|
||||
def test_example_1(tmp_path, lerobot_dataset_factory):
|
||||
_ = lerobot_dataset_factory(root=tmp_path, repo_id=DUMMY_REPO_ID)
|
||||
path = "examples/1_load_lerobot_dataset.py"
|
||||
_run_script(path)
|
||||
file_contents = _read_file(path)
|
||||
file_contents = _find_and_replace(
|
||||
file_contents,
|
||||
[
|
||||
('repo_id = "lerobot/pusht"', f'repo_id = "{DUMMY_REPO_ID}"'),
|
||||
(
|
||||
"LeRobotDataset(repo_id",
|
||||
f"LeRobotDataset(repo_id, root='{str(tmp_path)}', local_files_only=True",
|
||||
),
|
||||
],
|
||||
)
|
||||
exec(file_contents, {})
|
||||
assert Path("outputs/examples/1_load_lerobot_dataset/episode_0.mp4").exists()
|
||||
|
||||
|
||||
@pytest.mark.skip("TODO Fix and remove subprocess / excec calls")
|
||||
@require_package("gym_pusht")
|
||||
def test_examples_basic2_basic3_advanced1():
|
||||
"""
|
||||
@@ -111,7 +129,8 @@ def test_examples_basic2_basic3_advanced1():
|
||||
'# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
'pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")',
|
||||
),
|
||||
('split=f"train[{first_val_frame_index}:]"', 'split="train[30:]"'),
|
||||
("train_episodes = episodes[:num_train_episodes]", "train_episodes = [0]"),
|
||||
("val_episodes = episodes[num_train_episodes:]", "val_episodes = [1]"),
|
||||
("num_workers=4", "num_workers=0"),
|
||||
('device = torch.device("cuda")', 'device = torch.device("cpu")'),
|
||||
("batch_size=64", "batch_size=1"),

@@ -15,15 +15,12 @@
# limitations under the License.
from pathlib import Path

import numpy as np
import pytest
import torch
from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F  # noqa: N812

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.transforms import RandomSubsetApply, SharpnessJitter, get_image_transforms
from lerobot.common.utils.utils import init_hydra_config, seeded_context
from lerobot.scripts.visualize_image_transforms import visualize_transforms
@@ -33,21 +30,6 @@ ARTIFACT_DIR = Path("tests/data/save_image_transforms_to_safetensors")
DATASET_REPO_ID = "lerobot/aloha_mobile_shrimp"


def load_png_to_tensor(path: Path):
    return torch.from_numpy(np.array(Image.open(path).convert("RGB"))).permute(2, 0, 1)


@pytest.fixture
def img():
    dataset = LeRobotDataset(DATASET_REPO_ID)
    return dataset[0][dataset.camera_keys[0]]


@pytest.fixture
def img_random():
    return torch.rand(3, 480, 640)


@pytest.fixture
def color_jitters():
    return [
@@ -67,47 +49,54 @@ def default_transforms():
    return load_file(ARTIFACT_DIR / "default_transforms.safetensors")


def test_get_image_transforms_no_transform(img):
def test_get_image_transforms_no_transform(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf_actual = get_image_transforms(brightness_min_max=(0.5, 0.5), max_num_transforms=0)
    torch.testing.assert_close(tf_actual(img), img)
    torch.testing.assert_close(tf_actual(img_tensor), img_tensor)


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_brightness(img, min_max):
def test_get_image_transforms_brightness(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_actual = get_image_transforms(brightness_weight=1.0, brightness_min_max=min_max)
    tf_expected = v2.ColorJitter(brightness=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_contrast(img, min_max):
def test_get_image_transforms_contrast(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_actual = get_image_transforms(contrast_weight=1.0, contrast_min_max=min_max)
    tf_expected = v2.ColorJitter(contrast=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_saturation(img, min_max):
def test_get_image_transforms_saturation(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_actual = get_image_transforms(saturation_weight=1.0, saturation_min_max=min_max)
    tf_expected = v2.ColorJitter(saturation=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(-0.25, -0.25), (0.25, 0.25)])
def test_get_image_transforms_hue(img, min_max):
def test_get_image_transforms_hue(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_actual = get_image_transforms(hue_weight=1.0, hue_min_max=min_max)
    tf_expected = v2.ColorJitter(hue=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@pytest.mark.parametrize("min_max", [(0.5, 0.5), (2.0, 2.0)])
def test_get_image_transforms_sharpness(img, min_max):
def test_get_image_transforms_sharpness(img_tensor_factory, min_max):
    img_tensor = img_tensor_factory()
    tf_actual = get_image_transforms(sharpness_weight=1.0, sharpness_min_max=min_max)
    tf_expected = SharpnessJitter(sharpness=min_max)
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


def test_get_image_transforms_max_num_transforms(img):
def test_get_image_transforms_max_num_transforms(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf_actual = get_image_transforms(
        brightness_min_max=(0.5, 0.5),
        contrast_min_max=(0.5, 0.5),
@@ -125,12 +114,13 @@ def test_get_image_transforms_max_num_transforms(img):
            SharpnessJitter(sharpness=(0.5, 0.5)),
        ]
    )
    torch.testing.assert_close(tf_actual(img), tf_expected(img))
    torch.testing.assert_close(tf_actual(img_tensor), tf_expected(img_tensor))


@require_x86_64_kernel
def test_get_image_transforms_random_order(img):
def test_get_image_transforms_random_order(img_tensor_factory):
    out_imgs = []
    img_tensor = img_tensor_factory()
    tf = get_image_transforms(
        brightness_min_max=(0.5, 0.5),
        contrast_min_max=(0.5, 0.5),
@@ -141,13 +131,14 @@ def test_get_image_transforms_random_order(img):
    )
    with seeded_context(1337):
        for _ in range(10):
            out_imgs.append(tf(img))
            out_imgs.append(tf(img_tensor))

    for i in range(1, len(out_imgs)):
        with pytest.raises(AssertionError):
            torch.testing.assert_close(out_imgs[0], out_imgs[i])


@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
    "transform, min_max_values",
    [
@@ -158,21 +149,24 @@ def test_get_image_transforms_random_order(img):
        ("sharpness", [(0.5, 0.5), (2.0, 2.0)]),
    ],
)
def test_backward_compatibility_torchvision(transform, min_max_values, img, single_transforms):
def test_backward_compatibility_torchvision(img_tensor_factory, transform, min_max_values, single_transforms):
    img_tensor = img_tensor_factory()
    for min_max in min_max_values:
        kwargs = {
            f"{transform}_weight": 1.0,
            f"{transform}_min_max": min_max,
        }
        tf = get_image_transforms(**kwargs)
        actual = tf(img)
        actual = tf(img_tensor)
        key = f"{transform}_{min_max[0]}_{min_max[1]}"
        expected = single_transforms[key]
        torch.testing.assert_close(actual, expected)


@pytest.mark.skip("TODO after v2 migration / removing hydra")
@require_x86_64_kernel
def test_backward_compatibility_default_config(img, default_transforms):
def test_backward_compatibility_default_config(img_tensor_factory, default_transforms):
    img_tensor = img_tensor_factory()
    cfg = init_hydra_config(DEFAULT_CONFIG_PATH)
    cfg_tf = cfg.training.image_transforms
    default_tf = get_image_transforms(
@@ -191,7 +185,7 @@ def test_backward_compatibility_default_config(img, default_transforms):
    )

    with seeded_context(1337):
        actual = default_tf(img)
        actual = default_tf(img_tensor)

    expected = default_transforms["default"]

@@ -199,33 +193,36 @@ def test_backward_compatibility_default_config(img, default_transforms):


@pytest.mark.parametrize("p", [[0, 1], [1, 0]])
def test_random_subset_apply_single_choice(p, img):
def test_random_subset_apply_single_choice(img_tensor_factory, p):
    img_tensor = img_tensor_factory()
    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
    random_choice = RandomSubsetApply(flips, p=p, n_subset=1, random_order=False)
    actual = random_choice(img)
    actual = random_choice(img_tensor)

    p_horz, _ = p
    if p_horz:
        torch.testing.assert_close(actual, F.horizontal_flip(img))
        torch.testing.assert_close(actual, F.horizontal_flip(img_tensor))
    else:
        torch.testing.assert_close(actual, F.vertical_flip(img))
        torch.testing.assert_close(actual, F.vertical_flip(img_tensor))


def test_random_subset_apply_random_order(img):
def test_random_subset_apply_random_order(img_tensor_factory):
    img_tensor = img_tensor_factory()
    flips = [v2.RandomHorizontalFlip(p=1), v2.RandomVerticalFlip(p=1)]
    random_order = RandomSubsetApply(flips, p=[0.5, 0.5], n_subset=2, random_order=True)
    # We can't really check whether the transforms are actually applied in random order. However,
    # horizontal and vertical flip are commutative. Meaning, even under the assumption that the transform
    # applies them in random order, we can use a fixed order to compute the expected value.
    actual = random_order(img)
    expected = v2.Compose(flips)(img)
    actual = random_order(img_tensor)
    expected = v2.Compose(flips)(img_tensor)
    torch.testing.assert_close(actual, expected)
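
(The comment in the test above leans on horizontal and vertical flips commuting. A quick standalone check of that property, included here purely as an illustration and not as part of the test suite, could look like this:)

```python
import torch
from torchvision.transforms.v2 import functional as F  # noqa: N812

img = torch.rand(3, 4, 4)
hv = F.vertical_flip(F.horizontal_flip(img))
vh = F.horizontal_flip(F.vertical_flip(img))
torch.testing.assert_close(hv, vh)  # flips act on different axes, so either order gives the same result
```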


def test_random_subset_apply_valid_transforms(color_jitters, img):
def test_random_subset_apply_valid_transforms(img_tensor_factory, color_jitters):
    img_tensor = img_tensor_factory()
    transform = RandomSubsetApply(color_jitters)
    output = transform(img)
    assert output.shape == img.shape
    output = transform(img_tensor)
    assert output.shape == img_tensor.shape


def test_random_subset_apply_probability_length_mismatch(color_jitters):
@@ -239,16 +236,18 @@ def test_random_subset_apply_invalid_n_subset(color_jitters, n_subset):
        RandomSubsetApply(color_jitters, n_subset=n_subset)


def test_sharpness_jitter_valid_range_tuple(img):
def test_sharpness_jitter_valid_range_tuple(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf = SharpnessJitter((0.1, 2.0))
    output = tf(img)
    assert output.shape == img.shape
    output = tf(img_tensor)
    assert output.shape == img_tensor.shape


def test_sharpness_jitter_valid_range_float(img):
def test_sharpness_jitter_valid_range_float(img_tensor_factory):
    img_tensor = img_tensor_factory()
    tf = SharpnessJitter(0.5)
    output = tf(img)
    assert output.shape == img.shape
    output = tf(img_tensor)
    assert output.shape == img_tensor.shape


def test_sharpness_jitter_invalid_range_min_negative():
@@ -261,6 +260,7 @@ def test_sharpness_jitter_invalid_range_max_smaller():
        SharpnessJitter((2.0, 0.1))


@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
    "repo_id, n_examples",
    [

359
tests/test_image_writer.py
Normal file
@@ -0,0 +1,359 @@
import queue
import time
from multiprocessing import queues
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
from PIL import Image

from lerobot.common.datasets.image_writer import (
    AsyncImageWriter,
    image_array_to_image,
    safe_stop_image_writer,
    write_image,
)

DUMMY_IMAGE = "test_image.png"


def test_init_threading():
    writer = AsyncImageWriter(num_processes=0, num_threads=2)
    try:
        assert writer.num_processes == 0
        assert writer.num_threads == 2
        assert isinstance(writer.queue, queue.Queue)
        assert len(writer.threads) == 2
        assert len(writer.processes) == 0
        assert all(t.is_alive() for t in writer.threads)
    finally:
        writer.stop()


def test_init_multiprocessing():
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        assert writer.num_processes == 2
        assert writer.num_threads == 2
        assert isinstance(writer.queue, queues.JoinableQueue)
        assert len(writer.threads) == 0
        assert len(writer.processes) == 2
        assert all(p.is_alive() for p in writer.processes)
    finally:
        writer.stop()


def test_zero_threads():
    with pytest.raises(ValueError):
        AsyncImageWriter(num_processes=0, num_threads=0)
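
(The tests in this new file exercise the writer's full lifecycle. As a quick orientation for readers of the diff, a typical usage of the API they rely on looks roughly like the sketch below; the output path and frame size are made up for illustration.)

```python
from pathlib import Path

import numpy as np

from lerobot.common.datasets.image_writer import AsyncImageWriter

writer = AsyncImageWriter(num_processes=0, num_threads=4)  # threads only, as in test_init_threading
try:
    frame = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
    fpath = Path("outputs/frame_000000.png")  # hypothetical output path
    fpath.parent.mkdir(parents=True, exist_ok=True)
    writer.save_image(frame, fpath)  # returns quickly; the write happens on a worker
    writer.wait_until_done()  # block until all queued images are on disk
finally:
    writer.stop()  # tear down workers (calling stop twice is fine)
```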


def test_image_array_to_image_rgb(img_array_factory):
    img_array = img_array_factory(100, 100)
    result_image = image_array_to_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"


def test_image_array_to_image_pytorch_format(img_array_factory):
    img_array = img_array_factory(100, 100).transpose(2, 0, 1)
    result_image = image_array_to_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"


@pytest.mark.skip("TODO: implement")
def test_image_array_to_image_single_channel(img_array_factory):
    img_array = img_array_factory(channels=1)
    result_image = image_array_to_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "L"


def test_image_array_to_image_float_array(img_array_factory):
    img_array = img_array_factory(dtype=np.float32)
    result_image = image_array_to_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"
    assert np.array(result_image).dtype == np.uint8


def test_image_array_to_image_out_of_bounds_float():
    # Float array with values out of [0, 1]
    img_array = np.random.uniform(-1, 2, size=(100, 100, 3)).astype(np.float32)
    result_image = image_array_to_image(img_array)
    assert isinstance(result_image, Image.Image)
    assert result_image.size == (100, 100)
    assert result_image.mode == "RGB"
    assert np.array(result_image).dtype == np.uint8
    assert np.array(result_image).min() >= 0 and np.array(result_image).max() <= 255
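
(Together, the two float-array tests above pin down the conversion behaviour they expect: channel-first and out-of-range float inputs must still come back as a valid `uint8` RGB image. A minimal sketch of logic that would satisfy these assertions is given below; it is an illustration, not the library's actual `image_array_to_image`.)

```python
import numpy as np
from PIL import Image


def float_array_to_pil(img_array: np.ndarray) -> Image.Image:
    # Channel-first (c, h, w) arrays are moved to channel-last, as in the pytorch-format test.
    if img_array.ndim == 3 and img_array.shape[0] in (1, 3):
        img_array = img_array.transpose(1, 2, 0)
    if np.issubdtype(img_array.dtype, np.floating):
        # Clip to [0, 1] then scale to [0, 255] so out-of-bounds values stay representable.
        img_array = (np.clip(img_array, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(img_array)
```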


def test_write_image_numpy(tmp_path, img_array_factory):
    image_array = img_array_factory()
    fpath = tmp_path / DUMMY_IMAGE
    write_image(image_array, fpath)
    assert fpath.exists()
    saved_image = np.array(Image.open(fpath))
    assert np.array_equal(image_array, saved_image)


def test_write_image_image(tmp_path, img_factory):
    image_pil = img_factory()
    fpath = tmp_path / DUMMY_IMAGE
    write_image(image_pil, fpath)
    assert fpath.exists()
    saved_image = Image.open(fpath)
    assert list(saved_image.getdata()) == list(image_pil.getdata())
    assert np.array_equal(image_pil, saved_image)


def test_write_image_exception(tmp_path):
    image_array = "invalid data"
    fpath = tmp_path / DUMMY_IMAGE
    with patch("builtins.print") as mock_print:
        write_image(image_array, fpath)
        mock_print.assert_called()
        assert not fpath.exists()


def test_save_image_numpy(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    try:
        image_array = img_array_factory()
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        writer.save_image(image_array, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        assert np.array_equal(image_array, saved_image)
    finally:
        writer.stop()


def test_save_image_numpy_multiprocessing(tmp_path, img_array_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        image_array = img_array_factory()
        fpath = tmp_path / DUMMY_IMAGE
        writer.save_image(image_array, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        assert np.array_equal(image_array, saved_image)
    finally:
        writer.stop()


def test_save_image_torch(tmp_path, img_tensor_factory):
    writer = AsyncImageWriter()
    try:
        image_tensor = img_tensor_factory()
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        writer.save_image(image_tensor, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert np.array_equal(expected_image, saved_image)
    finally:
        writer.stop()


def test_save_image_torch_multiprocessing(tmp_path, img_tensor_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        image_tensor = img_tensor_factory()
        fpath = tmp_path / DUMMY_IMAGE
        writer.save_image(image_tensor, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = np.array(Image.open(fpath))
        expected_image = (image_tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        assert np.array_equal(expected_image, saved_image)
    finally:
        writer.stop()


def test_save_image_pil(tmp_path, img_factory):
    writer = AsyncImageWriter()
    try:
        image_pil = img_factory()
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        writer.save_image(image_pil, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = Image.open(fpath)
        assert list(saved_image.getdata()) == list(image_pil.getdata())
    finally:
        writer.stop()


def test_save_image_pil_multiprocessing(tmp_path, img_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        image_pil = img_factory()
        fpath = tmp_path / DUMMY_IMAGE
        writer.save_image(image_pil, fpath)
        writer.wait_until_done()
        assert fpath.exists()
        saved_image = Image.open(fpath)
        assert list(saved_image.getdata()) == list(image_pil.getdata())
    finally:
        writer.stop()


def test_save_image_invalid_data(tmp_path):
    writer = AsyncImageWriter()
    try:
        image_array = "invalid data"
        fpath = tmp_path / DUMMY_IMAGE
        fpath.parent.mkdir(parents=True, exist_ok=True)
        with patch("builtins.print") as mock_print:
            writer.save_image(image_array, fpath)
            writer.wait_until_done()
            mock_print.assert_called()
            assert not fpath.exists()
    finally:
        writer.stop()


def test_save_image_after_stop(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    writer.stop()
    image_array = img_array_factory()
    fpath = tmp_path / DUMMY_IMAGE
    writer.save_image(image_array, fpath)
    time.sleep(1)
    assert not fpath.exists()


def test_stop():
    writer = AsyncImageWriter(num_processes=0, num_threads=2)
    writer.stop()
    assert not any(t.is_alive() for t in writer.threads)


def test_stop_multiprocessing():
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    writer.stop()
    assert not any(p.is_alive() for p in writer.processes)


def test_multiple_stops():
    writer = AsyncImageWriter()
    writer.stop()
    writer.stop()  # Should not raise an exception
    assert not any(t.is_alive() for t in writer.threads)


def test_multiple_stops_multiprocessing():
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    writer.stop()
    writer.stop()  # Should not raise an exception
    assert not any(t.is_alive() for t in writer.threads)


def test_wait_until_done(tmp_path, img_array_factory):
    writer = AsyncImageWriter(num_processes=0, num_threads=4)
    try:
        num_images = 100
        image_arrays = [img_array_factory(height=500, width=500) for _ in range(num_images)]
        fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
        for image_array, fpath in zip(image_arrays, fpaths, strict=True):
            fpath.parent.mkdir(parents=True, exist_ok=True)
            writer.save_image(image_array, fpath)
        writer.wait_until_done()
        for i, fpath in enumerate(fpaths):
            assert fpath.exists()
            saved_image = np.array(Image.open(fpath))
            assert np.array_equal(saved_image, image_arrays[i])
    finally:
        writer.stop()


def test_wait_until_done_multiprocessing(tmp_path, img_array_factory):
    writer = AsyncImageWriter(num_processes=2, num_threads=2)
    try:
        num_images = 100
        image_arrays = [img_array_factory() for _ in range(num_images)]
        fpaths = [tmp_path / f"frame_{i:06d}.png" for i in range(num_images)]
        for image_array, fpath in zip(image_arrays, fpaths, strict=True):
            fpath.parent.mkdir(parents=True, exist_ok=True)
            writer.save_image(image_array, fpath)
        writer.wait_until_done()
        for i, fpath in enumerate(fpaths):
            assert fpath.exists()
            saved_image = np.array(Image.open(fpath))
            assert np.array_equal(saved_image, image_arrays[i])
    finally:
        writer.stop()


def test_exception_handling(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    try:
        image_array = img_array_factory()
        with (
            patch.object(writer.queue, "put", side_effect=queue.Full("Queue is full")),
            pytest.raises(queue.Full) as exc_info,
        ):
            writer.save_image(image_array, tmp_path / "test.png")
        assert str(exc_info.value) == "Queue is full"
    finally:
        writer.stop()


def test_with_different_image_formats(tmp_path, img_array_factory):
    writer = AsyncImageWriter()
    try:
        image_array = img_array_factory()
        formats = ["png", "jpeg", "bmp"]
        for fmt in formats:
            fpath = tmp_path / f"test_image.{fmt}"
            write_image(image_array, fpath)
            assert fpath.exists()
    finally:
        writer.stop()


def test_safe_stop_image_writer_decorator():
    class MockDataset:
        def __init__(self):
            self.image_writer = MagicMock(spec=AsyncImageWriter)

    @safe_stop_image_writer
    def function_that_raises_exception(dataset=None):
        raise Exception("Test exception")

    dataset = MockDataset()

    with pytest.raises(Exception) as exc_info:
        function_that_raises_exception(dataset=dataset)

    assert str(exc_info.value) == "Test exception"
    dataset.image_writer.stop.assert_called_once()
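
(The mock-based test above documents the contract of `safe_stop_image_writer`: if the wrapped function raises and was given a `dataset` carrying an `image_writer`, that writer gets stopped before the exception propagates. A rough sketch of a decorator with that behaviour follows; it is illustrative only, the real implementation lives in `lerobot/common/datasets/image_writer.py` and may differ.)

```python
import functools


def safe_stop_image_writer_sketch(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            # Stop the async writer so worker threads/processes don't linger after a failure.
            dataset = kwargs.get("dataset")
            image_writer = getattr(dataset, "image_writer", None) if dataset else None
            if image_writer is not None:
                image_writer.stop()
            raise

    return wrapper
```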


def test_main_process_time(tmp_path, img_tensor_factory):
    writer = AsyncImageWriter()
    try:
        image_tensor = img_tensor_factory()
        fpath = tmp_path / DUMMY_IMAGE
        start_time = time.perf_counter()
        writer.save_image(image_tensor, fpath)
        end_time = time.perf_counter()
        time_spent = end_time - start_time
        # Might need to adjust this threshold depending on hardware
        assert time_spent < 0.01, f"Main process time exceeded threshold: {time_spent}s"
        writer.wait_until_done()
        assert fpath.exists()
    finally:
        writer.stop()
@@ -19,11 +19,8 @@ from uuid import uuid4
import numpy as np
import pytest
import torch
from datasets import Dataset

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.online_buffer import OnlineBuffer, compute_sampler_weights
from lerobot.common.datasets.utils import hf_transform_to_torch

# Some constants for OnlineBuffer tests.
data_key = "data"
@@ -212,29 +209,17 @@ def test_delta_timestamps_outside_tolerance_outside_episode_range():


# Arbitrarily set small dataset sizes, making sure to have uneven sizes.
@pytest.mark.parametrize("offline_dataset_size", [0, 6])
@pytest.mark.parametrize("offline_dataset_size", [1, 6])
@pytest.mark.parametrize("online_dataset_size", [0, 4])
@pytest.mark.parametrize("online_sampling_ratio", [0.0, 1.0])
def test_compute_sampler_weights_trivial(
    offline_dataset_size: int, online_dataset_size: int, online_sampling_ratio: float
    lerobot_dataset_factory,
    tmp_path,
    offline_dataset_size: int,
    online_dataset_size: int,
    online_sampling_ratio: float,
):
    # Pass/skip the test if both dataset sizes are zero.
    if offline_dataset_size + online_dataset_size == 0:
        return
    # Create spoof offline dataset.
    offline_dataset = LeRobotDataset.from_preloaded(
        hf_dataset=Dataset.from_dict({"data": list(range(offline_dataset_size))})
    )
    offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
    if offline_dataset_size == 0:
        offline_dataset.episode_data_index = {}
    else:
        # Set up an episode_data_index with at least two episodes.
        offline_dataset.episode_data_index = {
            "from": torch.tensor([0, offline_dataset_size // 2]),
            "to": torch.tensor([offline_dataset_size // 2, offline_dataset_size]),
        }
    # Create spoof online dataset.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=offline_dataset_size)
    online_dataset, _ = make_new_buffer()
    if online_dataset_size > 0:
        online_dataset.add_data(
@@ -254,16 +239,9 @@ def test_compute_sampler_weights_trivial(
    assert torch.allclose(weights, expected_weights)
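
(For orientation, the weighting scheme these tests check can be illustrated with small numbers. This is the editor's reading of the expected behaviour rather than code from the diff: with `online_sampling_ratio = 0.8`, the online share of probability mass is spread uniformly over online frames and the remaining 0.2 over offline frames.)

```python
# Hypothetical illustration of the expected weighting, not part of the test suite.
import torch

offline_size, online_size = 4, 8
online_sampling_ratio = 0.8

expected = torch.cat(
    [
        torch.full((offline_size,), (1 - online_sampling_ratio) / offline_size),  # 0.05 each
        torch.full((online_size,), online_sampling_ratio / online_size),  # 0.10 each
    ]
)
assert torch.isclose(expected.sum(), torch.tensor(1.0))
```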


def test_compute_sampler_weights_nontrivial_ratio():
def test_compute_sampler_weights_nontrivial_ratio(lerobot_dataset_factory, tmp_path):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    # Create spoof offline dataset.
    offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
    offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
    offline_dataset.episode_data_index = {
        "from": torch.tensor([0, 2]),
        "to": torch.tensor([2, 4]),
    }
    # Create spoof online dataset.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
    online_sampling_ratio = 0.8
@@ -275,16 +253,9 @@ def test_compute_sampler_weights_nontrivial_ratio():
    )


def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n(lerobot_dataset_factory, tmp_path):
    # Arbitrarily set small dataset sizes, making sure to have uneven sizes.
    # Create spoof offline dataset.
    offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict({"data": list(range(4))}))
    offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
    offline_dataset.episode_data_index = {
        "from": torch.tensor([0]),
        "to": torch.tensor([4]),
    }
    # Create spoof online dataset.
    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=4)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))
    weights = compute_sampler_weights(
@@ -295,18 +266,9 @@ def test_compute_sampler_weights_nontrivial_ratio_and_drop_last_n():
    )


def test_compute_sampler_weights_drop_n_last_frames():
def test_compute_sampler_weights_drop_n_last_frames(lerobot_dataset_factory, tmp_path):
    """Note: test copied from test_sampler."""
    data_dict = {
        "timestamp": [0, 0.1],
        "index": [0, 1],
        "episode_index": [0, 0],
        "frame_index": [0, 1],
    }
    offline_dataset = LeRobotDataset.from_preloaded(hf_dataset=Dataset.from_dict(data_dict))
    offline_dataset.hf_dataset.set_transform(hf_transform_to_torch)
    offline_dataset.episode_data_index = {"from": torch.tensor([0]), "to": torch.tensor([2])}

    offline_dataset = lerobot_dataset_factory(tmp_path, total_episodes=1, total_frames=2)
    online_dataset, _ = make_new_buffer()
    online_dataset.add_data(make_spoof_data_frames(n_episodes=4, n_frames_per_episode=2))


@@ -50,7 +50,7 @@ def test_get_policy_and_config_classes(policy_name: str):
    assert issubclass(config_cls, inspect.signature(policy_cls.__init__).parameters["config"].annotation)


# TODO(aliberts): refactor using lerobot/__init__.py variables
@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
    "env_name,policy_name,extra_overrides",
    [
@@ -136,7 +136,7 @@ def test_policy(env_name, policy_name, extra_overrides):

    # Check that we can make the policy object.
    dataset = make_dataset(cfg)
    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
    # Check that the policy follows the required protocol.
    assert isinstance(
        policy, Policy
@@ -195,6 +195,7 @@ def test_policy(env_name, policy_name, extra_overrides):
    env.step(action)


@pytest.mark.skip("TODO after v2 migration / removing hydra")
def test_act_backbone_lr():
    """
    Test that the ACT policy can be instantiated with a different learning rate for the backbone.
@@ -213,7 +214,7 @@ def test_act_backbone_lr():
    assert cfg.training.lr_backbone == 0.001

    dataset = make_dataset(cfg)
    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.stats)
    policy = make_policy(hydra_cfg=cfg, dataset_stats=dataset.meta.stats)
    optimizer, _ = make_optimizer_and_scheduler(cfg, policy)
    assert len(optimizer.param_groups) == 2
    assert optimizer.param_groups[0]["lr"] == cfg.training.lr
@@ -351,6 +352,7 @@ def test_normalize(insert_temporal_dim):
    unnormalize(output_batch)


@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
    "env_name, policy_name, extra_overrides, file_name_extra",
    [
@@ -381,7 +383,7 @@ def test_backward_compatibility(env_name, policy_name, extra_overrides, file_nam
       include a report on what changed and how that affected the outputs.
    2. Go to the `if __name__ == "__main__"` block of `tests/scripts/save_policy_to_safetensors.py` and
       add the policies you want to update the test artifacts for.
    3. Run `DATA_DIR=tests/data python tests/scripts/save_policy_to_safetensors.py`. The test artifact
    3. Run `python tests/scripts/save_policy_to_safetensors.py`. The test artifact
       should be updated.
    4. Check that this test now passes.
    5. Remember to restore `tests/scripts/save_policy_to_safetensors.py` to its original state.

@@ -5,7 +5,7 @@ we skip them for now in our CI.

Example to run backward compatibility tests locally:
```
DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility
```
"""

@@ -250,6 +250,7 @@ def test_push_dataset_to_hub_out_dir_force_override_false(tmpdir):
    )


@pytest.mark.skip("TODO after v2 migration / removing hydra")
@pytest.mark.parametrize(
    "required_packages, raw_format, repo_id, make_test_data",
    [
@@ -329,7 +330,7 @@ def test_push_dataset_to_hub_format(required_packages, tmpdir, raw_format, repo_
    ],
)
@pytest.mark.skip(
    "Not compatible with our CI since it downloads raw datasets. Run with `DATA_DIR=tests/data python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
    "Not compatible with our CI since it downloads raw datasets. Run with `python -m pytest --run-skipped tests/test_push_dataset_to_hub.py::test_push_dataset_to_hub_pusht_backward_compatibility`"
)
def test_push_dataset_to_hub_pusht_backward_compatibility(tmpdir, raw_format, repo_id):
    _, dataset_id = repo_id.split("/")

@@ -15,9 +15,9 @@
# limitations under the License.
from datasets import Dataset

from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.sampler import EpisodeAwareSampler
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
)


@@ -7,10 +7,9 @@ import pytest
import torch
from datasets import Dataset

from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
from lerobot.common.datasets.utils import (
    calculate_episode_data_index,
    hf_transform_to_torch,
    reset_episode_index,
)
from lerobot.common.utils.utils import (
    get_global_random_state,
@@ -73,20 +72,6 @@ def test_calculate_episode_data_index():
    assert torch.equal(episode_data_index["to"], torch.tensor([2, 3, 6]))


def test_reset_episode_index():
    dataset = Dataset.from_dict(
        {
            "timestamp": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "index": [0, 1, 2, 3, 4, 5],
            "episode_index": [10, 10, 11, 12, 12, 12],
        },
    )
    dataset.set_transform(hf_transform_to_torch)
    correct_episode_index = [0, 0, 1, 2, 2, 2]
    dataset = reset_episode_index(dataset)
    assert dataset["episode_index"] == correct_episode_index
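
(The example data above fixes the expected behaviour of `reset_episode_index`: arbitrary episode ids (10, 11, 12) are remapped to consecutive indices starting at 0. A minimal sketch of that remapping is shown below; it is illustrative only, since the real helper operates on a Hugging Face `Dataset` and may be implemented differently.)

```python
def reset_episode_index_sketch(episode_index: list[int]) -> list[int]:
    # Map each distinct episode id to its rank in order of first appearance.
    mapping: dict[int, int] = {}
    for ep in episode_index:
        mapping.setdefault(ep, len(mapping))
    return [mapping[ep] for ep in episode_index]


assert reset_episode_index_sketch([10, 10, 11, 12, 12, 12]) == [0, 0, 1, 2, 2, 2]
```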


def test_init_hydra_config_empty():
    test_file = f"/tmp/test_init_hydra_config_empty_{uuid4().hex}.yaml"
    with open(test_file, "w") as f:

@@ -13,25 +13,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path

import pytest

from lerobot.scripts.visualize_dataset import visualize_dataset


@pytest.mark.parametrize(
    "repo_id",
    ["lerobot/pusht"],
)
@pytest.mark.parametrize("root", [Path(__file__).parent / "data"])
def test_visualize_local_dataset(tmpdir, repo_id, root):
@pytest.mark.skip("TODO: add dummy videos")
def test_visualize_local_dataset(tmp_path, lerobot_dataset_factory):
    root = tmp_path / "dataset"
    output_dir = tmp_path / "outputs"
    dataset = lerobot_dataset_factory(root=root)
    rrd_path = visualize_dataset(
        repo_id,
        dataset,
        episode_index=0,
        batch_size=32,
        save=True,
        output_dir=tmpdir,
        root=root,
        output_dir=output_dir,
    )
    assert rrd_path.exists()

@@ -14,23 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import pytest

from lerobot.scripts.visualize_dataset_html import visualize_dataset_html


@pytest.mark.parametrize(
    "repo_id",
    ["lerobot/pusht"],
)
def test_visualize_dataset_html(tmpdir, repo_id):
    tmpdir = Path(tmpdir)
def test_visualize_dataset_html(tmp_path, lerobot_dataset_factory):
    root = tmp_path / "dataset"
    output_dir = tmp_path / "outputs"
    dataset = lerobot_dataset_factory(root=root)
    visualize_dataset_html(
        repo_id,
        dataset,
        episodes=[0],
        output_dir=tmpdir,
        output_dir=output_dir,
        serve=False,
    )
    assert (tmpdir / "static" / "episode_0.csv").exists()
    assert (output_dir / "static" / "episode_0.csv").exists()